#include "../includeF/sfx_def.fi"

module sfx_sheba_coare

    ! modules used
    ! --------------------------------------------------------------------------------
#ifdef SFX_CHECK_NAN
    use sfx_common
#endif
    use sfx_data
    use sfx_surface
    use sfx_sheba_coare_param

#if defined(INCLUDE_CXX)
    use iso_c_binding, only: C_LOC, C_PTR, C_INT, C_FLOAT
    use C_FUNC
#endif
    ! --------------------------------------------------------------------------------

    ! directives list
    ! --------------------------------------------------------------------------------
    implicit none
    private
    ! --------------------------------------------------------------------------------

    ! public interface
    ! --------------------------------------------------------------------------------
    public :: get_surface_fluxes
    public :: get_surface_fluxes_vec
    public :: get_psi_stable
    public :: get_psi_a
    public :: get_psi_convection
    public :: get_psi_BD
    ! module-scope ids of the dynamic (z0m) and thermal (z0t) roughness schemes; not public.
    ! Module variables are shared by every call, so a procedure that uses them as per-cell
    ! scratch storage is not reentrant.
    integer :: z0m_id       !< dynamic (momentum) roughness scheme id
    integer :: z0t_id       !< thermal roughness scheme id
    ! --------------------------------------------------------------------------------
    ! --------------------------------------------------------------------------------
    type, public :: numericsType
        ! iteration limits of the scheme; each component has a default of 10
        integer :: maxiters_charnock = 10       !< maximum (actual) number of iterations in charnock roughness
        integer :: maxiters_convection = 10     !< maximum number of iterations of the convective solver (get_dynamic_scales)
    end type numericsType

    ! --------------------------------------------------------------------------------

#if defined(INCLUDE_CXX)
    ! Interoperable (BIND(C)) copy of the sfx_sheba_coare_param constants. It is filled by
    ! set_c_struct_sfx_sheba_coare_param_values and passed to c_sheba_coare_compute_flux.
    ! Component order and kinds define the C memory layout, so do not reorder components.
    ! The C/C++ side presumably declares a struct with the same field order -- verify there.
    type, BIND(C), public :: sfx_sheba_coare_param_C 
        ! von Karman constant and inverse turbulent Prandtl numbers
        real(C_FLOAT) :: kappa
        real(C_FLOAT) :: Pr_t_0_inv
        real(C_FLOAT) :: Pr_t_inf_inv

        ! stability-function coefficients (same names as in sfx_sheba_coare_param)
        real(C_FLOAT) :: alpha_m
        real(C_FLOAT) :: alpha_h
        real(C_FLOAT) :: gamma
        real(C_FLOAT) :: zeta_a
        real(C_FLOAT) ::  a_m
        real(C_FLOAT) ::  b_m
        real(C_FLOAT) ::  a_h
        real(C_FLOAT) ::  b_h
        real(C_FLOAT) ::  c_h
        real(C_FLOAT) ::  beta_m
        real(C_FLOAT) ::  beta_h
    end type

    ! Interoperable copy of the numericsType iteration limits. The component order here
    ! (maxiters_convection first) differs from numericsType. get_surface_fluxes_vec copies
    ! the values by name, so only the C-side struct must follow this order.
    type, BIND(C), public :: sfx_sheba_coare_numericsType_C 
        integer(C_INT) :: maxiters_convection
        integer(C_INT) :: maxiters_charnock 
    end type

    ! Explicit interface to the C/C++ flux kernel called from get_surface_fluxes_vec.
    !   sfx, meteo : C_LOC pointers to the sfxDataVecTypeC / meteoDataVecTypeC descriptors
    !   grid_size  : number of cells. It has no VALUE attribute, so it is passed by
    !                reference (pointer on the C side) -- confirm against the C prototype.
    ! NOTE(review): surface_param is declared as sfx_surface_sheba_coare_param here, but
    ! get_surface_fluxes_vec passes a type(sfx_surface_param) actual argument. Also, this
    ! interface body only USEs sfx_data, and interface bodies do not see host entities
    ! without IMPORT. Confirm which derived type is intended and where it is made accessible.
    INTERFACE
        SUBROUTINE c_sheba_coare_compute_flux(sfx, meteo, model_param, surface_param, numerics, constants, grid_size) BIND(C, & 
            name="c_sheba_coare_compute_flux")
            use sfx_data
            use, intrinsic :: ISO_C_BINDING, ONLY: C_INT, C_PTR
            Import :: sfx_sheba_coare_param_C, sfx_sheba_coare_numericsType_C
            implicit none
            integer(C_INT) :: grid_size
            type(C_PTR), value :: sfx
            type(C_PTR), value :: meteo
            type(sfx_sheba_coare_param_C) :: model_param
            type(sfx_surface_sheba_coare_param) :: surface_param
            type(sfx_sheba_coare_numericsType_C) :: numerics
            type(sfx_phys_constants) :: constants
        END SUBROUTINE c_sheba_coare_compute_flux

    END INTERFACE
#endif 

contains

    ! --------------------------------------------------------------------------------
#if defined(INCLUDE_CXX)
    subroutine set_c_struct_sfx_sheba_coare_param_values(sfx_model_param)
        !< @brief copy the model constants from sfx_sheba_coare_param into the
        !<        interoperable (BIND(C)) struct passed to the C/C++ flux kernel
        !< @details every component of sfx_sheba_coare_param_C is set from the module
        !<          constant of the same name. Default reals are converted to C_FLOAT on
        !<          assignment.
        ! ----------------------------------------------------------------------------
        type (sfx_sheba_coare_param_C), intent(inout) :: sfx_model_param
        ! ----------------------------------------------------------------------------

        sfx_model_param%kappa = kappa
        sfx_model_param%Pr_t_0_inv = Pr_t_0_inv
        sfx_model_param%Pr_t_inf_inv = Pr_t_inf_inv

        sfx_model_param%alpha_m = alpha_m
        sfx_model_param%alpha_h = alpha_h
        sfx_model_param%gamma = gamma
        sfx_model_param%zeta_a = zeta_a
        sfx_model_param%a_m = a_m
        sfx_model_param%b_m = b_m
        sfx_model_param%a_h = a_h
        sfx_model_param%b_h = b_h
        sfx_model_param%c_h = c_h

        ! --- convective coefficients: the struct components are beta_m / beta_h
        !     (sfx_sheba_coare_param_C has no c3 / c4 components)
        sfx_model_param%beta_m = beta_m
        sfx_model_param%beta_h = beta_h
    end subroutine set_c_struct_sfx_sheba_coare_param_values
#endif

    ! --------------------------------------------------------------------------------
    subroutine get_surface_fluxes_vec(sfx, meteo, numerics, n)
        !< @brief surface flux calculation for array data
        !< @details contains C/C++ & CUDA interface
        !< @details With INCLUDE_CXX the model parameters, iteration limits and physical
        !<          constants are packed into interoperable structs, and all n cells are
        !<          handed to c_sheba_coare_compute_flux in one call. Otherwise each cell is
        !<          copied into scalar meteo/sfx records, processed by get_surface_fluxes,
        !<          and the result is stored back with push_sfx_data.
        ! ----------------------------------------------------------------------------
        type (sfxDataVecType), intent(inout) :: sfx        !< surface flux data (output)
        type (meteoDataVecType), intent(in) :: meteo       !< meteorological data (input)
        type (numericsType), intent(in) :: numerics        !< iteration limits
        integer, intent(in) :: n                           !< number of cells
        ! ----------------------------------------------------------------------------

        ! --- local variables
        type (meteoDataType) :: cell_meteo
        type (sfxDataType) :: cell_sfx
        integer :: k
        ! ----------------------------------------------------------------------------
#if defined(INCLUDE_CXX)
        type (meteoDataVecTypeC), target :: meteo_vec_c     !< meteorological data (input)
        type (sfxDataVecTypeC), target :: sfx_vec_c         !< surface flux data (output)
        type (C_PTR) :: p_meteo, p_sfx
        type (sfx_sheba_coare_param_C) :: c_model_param
        type (sfx_surface_param) :: c_surface_param
        type (sfx_sheba_coare_numericsType_C) :: c_numerics
        type (sfx_phys_constants) :: c_constants

        ! --- model and surface parameters for the C side
        call set_c_struct_sfx_sheba_coare_param_values(c_model_param)
        call set_c_struct_sfx_surface_param_values(c_surface_param)

        ! --- iteration limits and physical constants
        c_numerics%maxiters_charnock = numerics%maxiters_charnock
        c_numerics%maxiters_convection = numerics%maxiters_convection

        c_constants%Pr_m = Pr_m
        c_constants%nu_air = nu_air
        c_constants%g = g

        ! --- C-compatible descriptors of the Fortran arrays
        call set_meteo_vec_c(meteo, meteo_vec_c)
        call set_sfx_vec_c(sfx, sfx_vec_c)
        p_meteo = C_LOC(meteo_vec_c)
        p_sfx = C_LOC(sfx_vec_c)

        call c_sheba_coare_compute_flux(p_sfx, p_meteo, c_model_param, c_surface_param, c_numerics, c_constants, n)
#else
        cells: do k = 1, n
            cell_meteo = meteoDataType( &
                    h = meteo%h(k), &
                    U = meteo%U(k), dT = meteo%dT(k), Tsemi = meteo%Tsemi(k), dQ = meteo%dQ(k), &
                    z0_m = meteo%z0_m(k), depth = meteo%depth(k), lai = meteo%lai(k), &
                    surface_type = meteo%surface_type(k))

            call get_surface_fluxes(cell_sfx, cell_meteo, numerics)
            call push_sfx_data(sfx, cell_sfx, k)
        end do cells
#endif

    end subroutine get_surface_fluxes_vec
    ! --------------------------------------------------------------------------------

    ! --------------------------------------------------------------------------------
    subroutine get_surface_fluxes(sfx, meteo, numerics)
        !< @brief surface flux calculation for single cell
        !< @details contains C/C++ interface
        !< @details Selects the roughness schemes for meteo%surface_type, then computes
        !<          z0_m, z0_t, Re, B and the bulk Richardson number Rib. The transfer
        !<          coefficients are then computed in one of three regimes:
        !<            Rib > 0            : stable (get_zeta_stable, get_psi_stable)
        !<            Rib <= -0.001      : unstable (iterative get_dynamic_scales, get_phi_a)
        !<            -0.001 < Rib <= 0  : near-neutral log profile (get_psi_neutral)
        !<          The roughness-scheme ids are procedure locals, so the routine keeps no
        !<          state between calls and separate cells can be processed concurrently.
        ! ----------------------------------------------------------------------------
#ifdef SFX_CHECK_NAN
        use ieee_arithmetic
#endif

        type (sfxDataType), intent(out) :: sfx

        type (meteoDataType), intent(in) :: meteo
        type (numericsType), intent(in) :: numerics
        ! ----------------------------------------------------------------------------
        ! --- meteo derived datatype name shadowing
        ! ----------------------------------------------------------------------------
        real :: h       !< constant flux layer height [m]
        real :: U       !< abs(wind speed) at 'h' [m/s]
        real :: dT      !< difference between potential temperature at 'h' and at surface [K]
        real :: Tsemi   !< semi-sum of potential temperature at 'h' and at surface [K]
        real :: dQ      !< difference between humidity at 'h' and at surface [g/g]
        real :: z0_m1   !< input roughness meteo%z0_m (should be < 0 for water bodies surface)
        !real :: hpbl
        real :: depth   !< meteo%depth, passed to get_dynamic_roughness_all
        real :: lai     !< meteo%lai, passed to get_thermal_roughness_all
        integer :: surface_type     !< meteo%surface_type, selects the roughness schemes
        ! ----------------------------------------------------------------------------

        ! --- local variables
        ! ----------------------------------------------------------------------------
        ! Declaring these here hides the module variables of the same name, so no
        ! module state is written.
        integer :: z0m_id       !< dynamic roughness scheme id for this cell
        integer :: z0t_id       !< thermal roughness scheme id for this cell

        real z0_m               !< aerodynamic roughness from get_dynamic_roughness_all [m]
        real z0_t               !< thermal roughness [m]
        real B                  !< = ln(z0_m / z0_t) [n/d]
        real h0_m, h0_t         !< = h / z0_m, h / z0_h [n/d]

        real u_dyn0             !< dynamic velocity in neutral conditions [m/s]
        real Re                 !< roughness Reynolds number = u_dyn0 * z0_m / nu [n/d]

        real zeta               !< = z/L [n/d]
        real Rib                !< bulk Richardson number

        real zeta_conv_lim      !< z/L critical value for matching free convection limit [n/d]
        real Rib_conv_lim       !< Ri-bulk critical value for matching free convection limit [n/d]

        real f_m_conv_lim       !< stability function (momentum) value in free convection limit [n/d]
        real f_h_conv_lim       !< stability function (heat) value in free convection limit [n/d]

        real psi_m, psi_h       !< universal functions (momentum) & (heat) [n/d]
        real psi0_m, psi0_h     !< universal functions at the roughness heights [n/d]

        real Udyn, Tdyn, Qdyn   !< dynamic scales

        real phi_m, phi_h       !< stability functions (momentum) & (heat) [n/d]

        real Km                 !< eddy viscosity coeff. at h [m^2/s]
        real Pr_t_inv           !< inverse Prandtl number [n/d]

        real Cm, Ct             !< transfer coeff. for (momentum) & (heat) [n/d]

#ifdef SFX_CHECK_NAN
        real NaN
#endif
        ! ----------------------------------------------------------------------------

#ifdef SFX_CHECK_NAN
        ! --- checking if arguments are finite
        if (.not.(is_finite(meteo%U).and.is_finite(meteo%Tsemi).and.is_finite(meteo%dT).and.is_finite(meteo%dQ) &
                .and.is_finite(meteo%z0_m).and.is_finite(meteo%h))) then

            NaN = ieee_value(0.0, ieee_quiet_nan)   ! setting NaN
            sfx = sfxDataType(zeta = NaN, Rib = NaN, &
                    Re = NaN, B = NaN, z0_m = NaN, z0_t = NaN, &
                    Rib_conv_lim = NaN, &
                    Cm = NaN, Ct = NaN, Km = NaN, Pr_t_inv = NaN)
            return
        end if
#endif

        ! --- shadowing names for clarity
        U = meteo%U
        Tsemi = meteo%Tsemi
        dT = meteo%dT
        dQ = meteo%dQ
        h = meteo%h
        z0_m1 = meteo%z0_m
        depth = meteo%depth
        lai = meteo%lai
        surface_type = meteo%surface_type
        !hpbl = meteo%hpbl

        ! --- dynamic roughness: choose the scheme for this surface type, then evaluate it
        call get_dynamic_roughness_definition(surface_type, ocean_z0m_id, land_z0m_id, lake_z0m_id, snow_z0m_id, &
                forest_z0m_id, usersf_z0m_id, ice_z0m_id, z0m_id)

        call get_dynamic_roughness_all(z0_m, u_dyn0, U, depth, h, numerics%maxiters_charnock, z0_m1, z0m_id)

        ! --- thermal roughness: choose the scheme, then evaluate it from Re
        call get_thermal_roughness_definition(surface_type, ocean_z0t_id, land_z0t_id, lake_z0t_id, snow_z0t_id, &
                forest_z0t_id, usersf_z0t_id, ice_z0t_id, z0t_id)

        Re = u_dyn0 * z0_m / nu_air

        call get_thermal_roughness_all(z0_t, B, z0_m, Re, u_dyn0, lai, z0t_id)

        ! --- define relative height
        h0_m = h / z0_m
        ! --- define relative height [thermal]
        h0_t = h / z0_t

        ! --- define Ri-bulk
        Rib = (g / Tsemi) * h * (dT + 0.61e0 * Tsemi * dQ) / U**2

        ! --- define free convection transition zeta = z/L value
        call get_convection_lim(zeta_conv_lim, Rib_conv_lim, f_m_conv_lim, f_h_conv_lim, &
                h0_m, h0_t, B)

        ! --- get the fluxes
        ! ----------------------------------------------------------------------------
        if (Rib > 0.0) then
            ! --- stable stratification block

            !   --- restrict bulk Ri value
            !   *: note that this value is written to output
!            Rib = min(Rib, Rib_max)

            call get_zeta_stable(zeta, Rib, h, z0_m, z0_t)

            call get_psi_stable(psi_m, psi_h, zeta, zeta)
            call get_psi_stable(psi0_m, psi0_h, zeta * z0_m / h, zeta * z0_t / h)

            phi_m = 1.0 + (a_m * zeta * (1.0 + zeta)**(1.0 / 3.0)) / (1.0 + b_m * zeta)
            phi_h = 1.0 + (a_h * zeta + b_h * zeta * zeta) / (1.0 + c_h * zeta + zeta * zeta)

            Udyn = kappa * U / (log(h / z0_m) - (psi_m - psi0_m))
            Tdyn = kappa * dT * Pr_t_0_inv / (log(h / z0_t) - (psi_h - psi0_h))

        else if (Rib <= -0.001) then
            ! --- unstable stratification block (iterative)

            call get_dynamic_scales(Udyn, Tdyn, Qdyn, zeta, psi_m, psi_h, psi0_m, psi0_h, &
                    U, Tsemi, dT, dQ, h, z0_m, z0_t, (g / Tsemi), numerics%maxiters_convection)

            call get_phi_a(phi_m, phi_h, zeta)
!!            call get_phi_a2(phi_m,phi_h,zeta)
!!            call get_phi_a3(phi_m,phi_h,zeta)

            ! psi_m / psi_h now hold the full profile denominators used for Cm / Ct below
            psi_m = (log(h / z0_m) - (psi_m - psi0_m))
            psi_h = (log(h / z0_t) - (psi_h - psi0_h))

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!!!!!!!!!!non-iterative version below is not debugged yet!!
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!            call get_zeta_conv(zeta,Rib,h,z0_m,z0_t)
!
!            call get_psi_a(psi_m, psi_h, zeta,zeta)
!            call get_psi_a(psi0_m, psi0_h, zeta * z0_m / h, zeta * z0_t / h)
!
!            Udyn = kappa * U / (log(h / z0_m) - (psi_m - psi0_m))
!            Tdyn = kappa * dT * Pr_t_0_inv / (log(h / z0_t) - (psi_h - psi0_h))
!
!            call get_phi_a(phi_m,phi_h,zeta)
!
!            psi_m = (log(h / z0_m) - (psi_m - psi0_m))
!            psi_h = (log(h / z0_t) - (psi_h - psi0_h))

        else
            ! --- nearly neutral [-0.001, 0] block

            call get_psi_neutral(psi_m, psi_h, h0_m, h0_t, B)

            zeta = 0.0
            phi_m = 1.0
            phi_h = 1.0 / Pr_t_0_inv

            Udyn = kappa * U / log(h / z0_m)
            Tdyn = kappa * dT * Pr_t_0_inv / log(h / z0_t)

        end if
        ! ----------------------------------------------------------------------------

        ! --- define transfer coeff. (momentum) & (heat)
        if (Rib > 0) then
            ! stable: from the dynamic scales, guarding the zero-wind / zero-dT cases
            Cm = 0.0
            if (U > 0.0) then
                Cm = Udyn / U
            end if
            Ct = 0.0
            if (abs(dT) > 0.0) then
                Ct = Tdyn / dT
            end if
        else
            Cm = kappa / psi_m
            Ct = kappa / psi_h
        end if

        ! --- define eddy viscosity & inverse Prandtl number
        Km = kappa * Cm * U * h / phi_m
        Pr_t_inv = phi_m / phi_h

        ! --- setting output
        sfx = sfxDataType(zeta = zeta, Rib = Rib, &
            Re = Re, B = B, z0_m = z0_m, z0_t = z0_t, &
            Rib_conv_lim = Rib_conv_lim, &
            Cm = Cm, Ct = Ct, Km = Km, Pr_t_inv = Pr_t_inv)

    end subroutine get_surface_fluxes
    ! --------------------------------------------------------------------------------



    ! universal functions
    ! --------------------------------------------------------------------------------
    subroutine get_psi_neutral(psi_m, psi_h, h0_m, h0_t, B)
        !< @brief universal functions (momentum) & (heat): neutral case
        !< @details log-law values psi_m = ln(h0_m) and psi_h = ln(h0_t) / Pr_t_0_inv.
        !<          When |B| < 1e-10 (z0_t ~ z0_m) the heat value is derived from psi_m.
        ! ----------------------------------------------------------------------------
        real, intent(out) :: psi_m, psi_h   !< universal functions

        real, intent(in) :: h0_m, h0_t      !< = z/z0_m, z/z0_h
        real, intent(in) :: B               !< = log(z0_m / z0_h)
        ! ----------------------------------------------------------------------------

        psi_m = log(h0_m)

        if (abs(B) < 1.0e-10) then
            ! identical roughness lengths: reuse the momentum logarithm
            psi_h = psi_m / Pr_t_0_inv
        else
            psi_h = log(h0_t) / Pr_t_0_inv
        end if

    end subroutine get_psi_neutral


    subroutine get_zeta_stable(zeta, Rib, h, z0_m, z0_t)
        !< @brief non-iterative zeta = z/L from bulk Ri for stable stratification
        !< @details zeta = C1 * Ribl + A1 * A2 * Ribl**gamma, where
        !<            Ribl = Rib * Pr_t_0_inv * (1 - z0_t/h) / (1 - z0_m/h)**2,
        !<            C1   = ln(h/z0_m)**2 / ln(h/z0_t),
        !<          and A1, A2 use get_psi_stable evaluated at the reference stability
        !<          zeta_a and at zeta_a * z0 / h.
        ! ----------------------------------------------------------------------------
        real, intent(out) :: zeta
        real, intent(in) :: Rib, h, z0_m, z0_t

        real :: rib_scaled          !< Ribl
        real :: log_m, log_t        !< ln(h/z0_m), ln(h/z0_t)
        real :: coef_lin            !< C1
        real :: fac_a, fac_b        !< A1, A2
        real :: pm_ref, ph_ref      !< psi_m, psi_h at zeta_a
        real :: pm0_ref, ph0_ref    !< psi_m, psi_h at zeta_a * z0 / h
        ! ----------------------------------------------------------------------------

        rib_scaled = (Rib * Pr_t_0_inv) * (1 - z0_t / h) / ((1 - z0_m / h)**2)

        call get_psi_stable(pm_ref, ph_ref, zeta_a, zeta_a)
        call get_psi_stable(pm0_ref, ph0_ref, zeta_a * z0_m / h,  zeta_a * z0_t / h)

        log_m = log(h/z0_m)
        log_t = log(h/z0_t)
        coef_lin = (log_m**2)/log_t

        fac_a = ((log_m - pm_ref + pm0_ref)**(2*(gamma-1))) &
                / ((zeta_a**(gamma-1))*((log_t-(ph_ref-ph0_ref)*Pr_t_0_inv)**(gamma-1)))
        fac_b = ((log_m - pm_ref + pm0_ref)**2) / (log_t-(ph_ref-ph0_ref)*Pr_t_0_inv) - coef_lin

        zeta = coef_lin * rib_scaled + fac_a * fac_b * (rib_scaled**gamma)

    end subroutine get_zeta_stable



    subroutine get_psi_stable(psi_m, psi_h, zeta_m, zeta_h)
        !< @brief universal functions (momentum) & (heat): stable case
        !< @details closed-form integrated profile functions built from a_m, b_m
        !<          (momentum) and a_h, b_h, c_h (heat). They are used by the Rib > 0
        !<          branch of get_surface_fluxes and by get_zeta_stable.
        !<          NOTE(review): sqrt(c_h**2 - 4) assumes |c_h| >= 2 -- confirm against
        !<          the parameter set.
        ! ----------------------------------------------------------------------------
        real, intent(out) :: psi_m, psi_h   !< universal functions

        real, intent(in) :: zeta_m, zeta_h  !< = z/L
        ! ----------------------------------------------------------------------------

        ! --- local variables
        real :: s3               !< sqrt(3)
        real :: ratio_m          !< a_m / b_m
        real :: root_m           !< ((1 - b_m) / b_m)**(1/3)
        real :: cbrt_m           !< (1 + zeta_m)**(1/3)
        real :: disc_h           !< sqrt(c_h**2 - 4)
        ! ----------------------------------------------------------------------------

        ! --- momentum
        s3 = sqrt(3.0)
        ratio_m = a_m / b_m
        root_m = ((1.0 - b_m) / b_m)**(1.0 / 3.0)
        cbrt_m = (1.0 + zeta_m)**(1.0 / 3.0)

        psi_m = -3.0 * ratio_m * (cbrt_m - 1.0) + 0.5 * ratio_m * root_m * ( &
                2.0 * log((cbrt_m + root_m) / (1.0 + root_m)) - &
                log((cbrt_m * cbrt_m - cbrt_m * root_m + root_m * root_m) / (1.0 - root_m + root_m * root_m)) + &
                2.0 * s3 * (atan((2.0 * cbrt_m - root_m) / (s3 * root_m)) - atan((2.0 - root_m) / (s3 * root_m))))

        ! --- heat
        disc_h = sqrt(c_h * c_h - 4.0)

        psi_h = -0.5 * b_h * log(1.0 + c_h * zeta_h + zeta_h * zeta_h) + &
                ((-a_h / disc_h) + ((b_h * c_h) / (2.0 * disc_h))) * ( &
                log((2.0 * zeta_h + c_h - disc_h) / (2.0 * zeta_h + c_h + disc_h)) - &
                log((c_h - disc_h) / (c_h + disc_h)))

    end subroutine get_psi_stable



    subroutine get_psi_convection(psi_m, psi_h, zeta_m, zeta_h)
        !< @brief Carl et al. 1973 with Grachev et al. 2000 corrections of beta_m, beta_h
        !< @details free-convection universal functions. With y = (1 - beta*zeta)**(1/3):
        !<            psi = 1.5*ln((y**2 + y + 1)/3) - sqrt(3)*atan((2y + 1)/sqrt(3)) + pi/sqrt(3)
        !<          beta_m is used for momentum and beta_h for heat. The original author
        !<          noted beta_m = 10, beta_h = 34; the actual values come from
        !<          sfx_sheba_coare_param.
        ! ----------------------------------------------------------------------------
        real, intent(out) :: psi_m, psi_h               !< universal functions [n/d]
        real, intent(in) :: zeta_m, zeta_h              !< = z/L [n/d]

        real :: y_m, y_h
        ! ----------------------------------------------------------------------------

        y_m = (1.0 - beta_m * zeta_m)**(1/3.)
        psi_m = 3.0 * 0.5 *log((y_m*y_m + y_m + 1.0)/3.) - sqrt(3.0) *atan((2.0*y_m + 1)/sqrt(3.0)) + pi/sqrt(3.0)

        y_h = (1.0 - beta_h * zeta_h)**(1/3.)
        psi_h = 3.0 * 0.5 *log((y_h*y_h + y_h + 1.0)/3.) - sqrt(3.0) *atan((2.0*y_h + 1)/sqrt(3.0)) + pi/sqrt(3.0)

    end subroutine get_psi_convection

    subroutine get_psi_BD(psi_m, psi_h, zeta_m, zeta_h)
        !< @brief universal functions (momentum) & (heat): Businger-Dyer (BD) form
        !< @details with x = (1 - alpha*zeta)**(1/4) (alpha_m or alpha_h):
        !<            psi_m = pi/2 + 2*ln((1 + x)/2) + ln((1 + x**2)/2) - 2*atan(x)
        !<            psi_h = 2*ln((1 + x**2)/2)
        !<          Real-valued only while alpha*zeta <= 1.
        ! ----------------------------------------------------------------------------
        real, intent(out) :: psi_m, psi_h   !< universal functions

        real, intent(in) :: zeta_m,zeta_h            !< = z/L
        ! ----------------------------------------------------------------------------

        ! --- local variables
        real :: xm, xh          !< (1 - alpha * zeta)**(1/4)
        real :: half_pi         !< pi / 2, computed as 4*atan(1)/2
        ! ----------------------------------------------------------------------------

        half_pi = 4.0 * atan(1.0) / 2.0

        xm = (1.0 - alpha_m * zeta_m)**(0.25)
        psi_m = half_pi + 2.0 * log(0.5 * (1.0 + xm)) + log(0.5 * (1.0 + xm * xm)) - 2.0 * atan(xm)

        xh = (1.0 - alpha_h * zeta_h)**(0.25)
        psi_h = 2.0 * log(0.5 * (1.0 + xh * xh))

    end subroutine get_psi_BD

    subroutine get_phi_a(phi_m, phi_h, zeta)
        !< @brief stability functions (momentum) & (heat) consistent with get_psi_a
        !< @details get_psi_a blends psi_a = (psi_BD + zeta**2 * psi_conv) / (1 + zeta**2).
        !<          Here phi = 1 - zeta * d(psi_a)/d(zeta), evaluated analytically from
        !<          the BD and convective functions and their derivatives.
        !<          The blending weights are written as 1/(1+zeta**2), zeta**2/(1+zeta**2)
        !<          and 2*zeta/(1+zeta**2)**2, which stay finite at zeta = 0 (phi = 1
        !<          there). The equivalent forms with 1/zeta**2 and 1/zeta**3 give NaN at
        !<          zeta = 0, which get_dynamic_scales returns when there is no wind.
        ! ----------------------------------------------------------------------------
        real, intent(out) :: phi_m, phi_h   !< stability functions [n/d]

        real, intent(in) :: zeta            !< = z/L
        ! ----------------------------------------------------------------------------

        ! --- local variables
        real :: x_m, x_h, y
        real :: psi_m_bd,psi_h_bd,psi_m_conv,psi_h_conv
        real :: dpsi_m_bd,dpsi_h_bd,dpsi_m_conv,dpsi_h_conv
        real :: zeta2       !< zeta**2
        real :: w_bd        !< weight of the BD functions, 1/(1+zeta**2)
        real :: w_conv      !< weight of the convective functions, zeta**2/(1+zeta**2)
        real :: dw_conv     !< d(w_conv)/d(zeta) = -d(w_bd)/d(zeta) = 2*zeta/(1+zeta**2)**2
        ! ----------------------------------------------------------------------------

        call get_psi_BD(psi_m_bd,psi_h_bd,zeta,zeta)
        call get_psi_convection(psi_m_conv,psi_h_conv,zeta,zeta)

        ! --- d(psi_BD)/d(zeta) via x = (1 - alpha*zeta)**(1/4)
        x_m = (1.0 - alpha_m * zeta)**(0.25)
        x_h = (1.0 - alpha_h * zeta)**(0.25)
        dpsi_m_bd = -(alpha_m/(2.0*(x_m**3))) * (1/(1+x_m) + (x_m-1)/(1+x_m**2))
        dpsi_h_bd = -alpha_h / ((x_h**2)*(1+x_h**2))

        ! --- d(psi_conv)/d(zeta) via y = (1 - beta*zeta)**(1/3)
        y = (1 - beta_m * zeta)**(1/3.)
        dpsi_m_conv = -beta_m/(y*(y**2 + y + 1))
        y = (1 - beta_h * zeta)**(1/3.)
        dpsi_h_conv = -beta_h/(y*(y**2 + y + 1))

        ! --- blending weights and their derivative (no division by zeta)
        zeta2 = zeta * zeta
        w_bd = 1.0 / (1.0 + zeta2)
        w_conv = zeta2 / (1.0 + zeta2)
        dw_conv = 2.0 * zeta / ((1.0 + zeta2)**2)

        phi_m = 1.0 - zeta * (dpsi_m_bd * w_bd - psi_m_bd * dw_conv + &
                dpsi_m_conv * w_conv + psi_m_conv * dw_conv)

        phi_h = 1.0 - zeta * (dpsi_h_bd * w_bd - psi_h_bd * dw_conv + &
                dpsi_h_conv * w_conv + psi_h_conv * dw_conv)

    end subroutine get_phi_a

    subroutine get_phi_a2(phi_m,phi_h,zeta)
        !< @brief stability functions blended directly (alternative to get_phi_a)
        !< @details phi = (phi_BD + zeta**2 * phi_conv) / (1 + zeta**2): the get_psi_a
        !<          weights are applied to phi instead of psi. It is only referenced from
        !<          a commented-out call in get_surface_fluxes.
        ! ----------------------------------------------------------------------------
        real, intent(out) :: phi_m, phi_h   !< universal functions

        real, intent(in) :: zeta            !< = z/L
        ! ----------------------------------------------------------------------------

        ! --- local variables
        real :: conv_m, conv_h, bd_m, bd_h
        real :: zz                          !< zeta**2

        call get_phi_convection(conv_m, conv_h, zeta)
        call get_phi_BD(bd_m, bd_h, zeta)

        zz = zeta**2
        phi_m = (bd_m + zz * conv_m) / (1 + zz)
        phi_h = (bd_h + zz * conv_h) / (1 + zz)

    end subroutine get_phi_a2

    subroutine get_phi_a3(phi_m,phi_h,zeta)
        !< @brief alternative blended stability functions (experimental)
        !< @details combines get_phi_BD / get_phi_convection with get_psi_BD /
        !<          get_psi_convection. It is only referenced from a commented-out call
        !<          in get_surface_fluxes.
        !<          NOTE(review): the first term divides by zeta, so zeta = 0 gives Inf/NaN.
        ! ----------------------------------------------------------------------------
        real, intent(out) :: phi_m, phi_h   !< universal functions

        real, intent(in) :: zeta            !< = z/L
        ! ----------------------------------------------------------------------------

        ! --- local variables
        real :: phi_m_bd,phi_h_bd,phi_m_conv,phi_h_conv
        real :: psi_m_conv,psi_h_conv,psi_m_BD,psi_h_BD

        call get_phi_convection(phi_m_conv, phi_h_conv, zeta)
        call get_phi_BD(phi_m_BD, phi_h_BD, zeta)

        call get_psi_convection(psi_m_conv, psi_h_conv, zeta,zeta)
        call get_psi_BD(psi_m_BD, psi_h_BD, zeta,zeta)

        phi_m = (1-phi_m_BD)/(zeta*(1+zeta**2)) + 2*zeta*(psi_m_conv-psi_m_BD)/((1+zeta**2)**2) &
                + zeta*(1-phi_m_conv)/((1+zeta**2)**2)
        phi_h = (1-phi_h_BD)/(zeta*(1+zeta**2)) + 2*zeta*(psi_h_conv-psi_h_BD)/((1+zeta**2)**2) &
                + zeta*(1-phi_h_conv)/((1+zeta**2)**2)

    end subroutine get_phi_a3

    subroutine get_phi_BD(phi_m, phi_h, zeta)
        !< @brief stability functions (momentum) & (heat): Businger-Dyer form
        !< @details phi_m = (1 - alpha_m*zeta)**(-1/4), phi_h = (1 - alpha_h*zeta)**(-1/2)
        ! ----------------------------------------------------------------------------
        real, intent(out) :: phi_m, phi_h   !< stability functions

        real, intent(in) :: zeta            !< = z/L
        ! ----------------------------------------------------------------------------

        real :: base_m, base_h              !< 1 - alpha * zeta

        base_m = 1.0 - alpha_m * zeta
        base_h = 1.0 - alpha_h * zeta

        phi_m = base_m**(-0.25)
        phi_h = base_h**(-0.5)

    end subroutine get_phi_BD

    subroutine get_phi_convection(phi_m, phi_h, zeta)
        !< @brief stability functions (momentum) & (heat): free-convection form
        !< @details phi_m = (1 - beta_m*zeta)**(-1/3), phi_h = (1 - beta_h*zeta)**(-1/3)
        ! ----------------------------------------------------------------------------
        real, intent(out) :: phi_m, phi_h   !< stability functions

        real, intent(in) :: zeta            !< = z/L
        ! ----------------------------------------------------------------------------

        real :: base_m, base_h              !< 1 - beta * zeta

        base_m = 1.0 - beta_m * zeta
        base_h = 1.0 - beta_h * zeta

        phi_m = base_m**(-1.0/3.0)
        phi_h = base_h**(-1.0/3.0)

    end subroutine get_phi_convection

        !< @brief get dynamic scales
    ! --------------------------------------------------------------------------------
    subroutine get_dynamic_scales(Udyn, Tdyn, Qdyn, zeta, psi_m, psi_h, &
            psi0_m,psi0_h, U, Tsemi, dT, dQ, z, z0_m, z0_t, beta, maxiters)
        !< @brief get dynamic scales
        !< @details Fixed-point iteration for the dynamic scales in unstable conditions.
        !<          It starts from the neutral log profile and then repeats maxiters times:
        !<          get_psi_a at z and at the roughness heights -> updated Udyn / Tdyn /
        !<          Qdyn -> new 1/L and zeta = z/L. No convergence test is made; the loop
        !<          exits early only when Udyn drops below 1e-5.
        !<          All intent(out) arguments are defined on every return path. psi_m and
        !<          psi0_m (and psi_h and psi0_h) are first set to equal values, so a caller
        !<          forming log(z/z0) - (psi - psi0) gets the neutral log profile when the
        !<          iteration is skipped (no wind, or maxiters < 1).
        ! ----------------------------------------------------------------------------
        real, intent(out) :: Udyn, Tdyn, Qdyn   !< dynamic scales
        real, intent(out) :: zeta   !< = z/L
        real, intent(out) :: psi_m,psi_h,psi0_m,psi0_h  !< universal functions at z and at z0

        real, intent(in) :: U                   !< abs(wind speed) at z
        real, intent(in) :: Tsemi               !< semi-sum of temperature at z and at surface
        real, intent(in) :: dT, dQ              !< temperature & humidity difference between z and at surface
        real, intent(in) :: z                   !< constant flux layer height
        real, intent(in) :: z0_m, z0_t          !< roughness parameters
        real, intent(in) :: beta                !< buoyancy parameter (g / Tsemi at the call site in this module)

        integer, intent(in) :: maxiters         !< maximum number of iterations
        ! ----------------------------------------------------------------------------

        ! --- local variables
        real :: Linv                            !< 1/L estimate, zeta = z * Linv
        integer :: i
        ! ----------------------------------------------------------------------------

        ! --- neutral first guess
        Udyn = kappa * U / log(z / z0_m)
        Tdyn = kappa * dT * Pr_t_0_inv / log(z / z0_t)
        Qdyn = kappa * dQ * Pr_t_0_inv / log(z / z0_t)
        zeta = 0.0

        ! --- define the universal-function outputs before any early return
        !     (psi - psi0 = 0 reproduces the neutral profile in the caller)
        psi_m = log(z/z0_m)
        psi_h = log(z/z0_t) / Pr_t_0_inv
        psi0_m = log(z/z0_m)
        psi0_h = log(z/z0_t) / Pr_t_0_inv

        ! --- no wind
        if (Udyn < 1e-5) return

        Linv = kappa * beta * (Tdyn + 0.61 * Qdyn * Tsemi) / (Udyn * Udyn)
        zeta = z * Linv

        ! --- near neutral case
!        if (Linv < 1e-5) return
        do i = 1, maxiters

            call get_psi_a(psi_m, psi_h, zeta, zeta)
            call get_psi_a(psi0_m, psi0_h, z0_m * Linv, z0_t * Linv)

            Udyn = kappa * U / (log(z / z0_m) - (psi_m - psi0_m))
            Tdyn = kappa * dT * Pr_t_0_inv / (log(z / z0_t) - (psi_h - psi0_h))
            Qdyn = kappa * dQ * Pr_t_0_inv / (log(z / z0_t) - (psi_h - psi0_h))

            if (Udyn < 1e-5) exit

            Linv = kappa * beta * (Tdyn + 0.61 * Qdyn * Tsemi) / (Udyn * Udyn)
            zeta = z * Linv
        end do

    end subroutine get_dynamic_scales

    subroutine get_psi_a(psi_m,psi_h,zeta_m, zeta_h)
        !< @brief blended universal functions for unstable stratification
        !< @details psi = (psi_BD + zeta**2 * psi_conv) / (1 + zeta**2). The BD form
        !<          dominates for |zeta| << 1 and the convective form for |zeta| >> 1.
        ! ----------------------------------------------------------------------------
        real, intent(out) :: psi_m, psi_h   !< universal functions

        real, intent(in) :: zeta_m,zeta_h            !< = z/L
        ! ----------------------------------------------------------------------------

        ! --- local variables
        real :: cv_m, cv_h      !< convective functions
        real :: bd_m, bd_h      !< Businger-Dyer functions
        real :: z2_m, z2_h      !< zeta**2 (blending weights)

        call get_psi_convection(cv_m, cv_h, zeta_m, zeta_h)
        call get_psi_BD(bd_m, bd_h, zeta_m, zeta_h)

        z2_m = zeta_m**2
        z2_h = zeta_h**2

        psi_m = (bd_m + z2_m * cv_m) / (1 + z2_m)
        psi_h = (bd_h + z2_h * cv_h) / (1 + z2_h)

    end subroutine get_psi_a

    subroutine get_convection_lim(zeta_lim, Rib_lim, f_m_lim, f_h_lim, &
            h0_m, h0_t, B)
        !< @brief free-convection matching limit: limiting z/L, Ri-bulk and function shortcuts
        !< @details zeta_lim is computed in closed form from alpha_m, alpha_h and
        !<          (Pr_t_inf_inv / Pr_t_0_inv)**4. Rib_lim = zeta_lim * psi_h / psi_m**2
        !<          uses the integrated functions between z and the roughness heights.
        ! ----------------------------------------------------------------------------
        real, intent(out) :: zeta_lim           !< limiting value of z/L
        real, intent(out) :: Rib_lim            !< limiting value of Ri-bulk
        real, intent(out) :: f_m_lim, f_h_lim   !< limiting values of universal functions shortcuts

        real, intent(in) :: h0_m, h0_t          !< = z/z0_m, z/z0_h [n/d]
        real, intent(in) :: B                   !< = log(z0_m / z0_h) [n/d]
        ! ----------------------------------------------------------------------------

        ! --- local variables
        real :: prof_m, prof_h      !< integrated functions (momentum) & (heat)
        real :: s_m, s_h            !< shortcuts at the roughness heights
        real :: ratio4              !< (Pr_t_inf_inv / Pr_t_0_inv)**4
        ! ----------------------------------------------------------------------------

        ! --- define limiting value of zeta = z / L
        ratio4 = (Pr_t_inf_inv / Pr_t_0_inv)**4
        zeta_lim = (2.0 * alpha_h - ratio4 * alpha_m - &
                sqrt((ratio4 * alpha_m)**2 + 4.0 * ratio4 * alpha_h * (alpha_h - alpha_m))) / (2.0 * alpha_h**2)

        f_m_lim = f_m_conv(zeta_lim)
        f_h_lim = f_h_conv(zeta_lim)

        ! --- stability parameter scaled to the roughness heights
        s_m = zeta_lim / h0_m
        if (abs(B) < 1.0e-10) then
            s_h = s_m
        else
            s_h = zeta_lim / h0_t
        end if

        s_m = (1.0 - alpha_m * s_m)**0.25
        s_h = sqrt(1.0 - alpha_h_fix * s_h)

        ! --- universal functions (generic log instead of the specific alog)
        prof_m = 2.0 * (atan(f_m_lim) - atan(s_m)) + log((f_m_lim - 1.0) * (s_m + 1.0)/((f_m_lim + 1.0) * (s_m - 1.0)))
        prof_h = log((f_h_lim - 1.0) * (s_h + 1.0)/((f_h_lim + 1.0) * (s_h - 1.0))) / Pr_t_0_inv

        ! --- bulk Richardson number
        Rib_lim = zeta_lim * prof_h / (prof_m * prof_m)

    end subroutine get_convection_lim

    ! convection universal functions shortcuts
    ! --------------------------------------------------------------------------------
    function f_m_conv(zeta) result(f)
        !< @brief momentum shortcut (1 - alpha_m * zeta)**(1/4) for the convection limit
        ! ----------------------------------------------------------------------------
        real, intent(in) :: zeta        !< = z/L
        real :: f
        ! ----------------------------------------------------------------------------

        f = (1.0 - alpha_m * zeta)**0.25

    end function f_m_conv

    function f_h_conv(zeta) result(f)
        !< @brief heat shortcut (1 - alpha_h * zeta)**(1/2) for the convection limit
        ! ----------------------------------------------------------------------------
        real, intent(in) :: zeta        !< = z/L
        real :: f
        ! ----------------------------------------------------------------------------

        f = (1.0 - alpha_h * zeta)**0.5

    end function f_h_conv

    subroutine get_zeta_conv(zeta,Rib,z,z0m,z0t)
        !< @brief Srivastava and Sharan 2017, Abdella and Assefa 2005
        !< @details Non-iterative estimate of zeta = z/L from bulk Ri in unstable
        !<          conditions. It builds a0, a1, a2 (presumably the coefficients of
        !<          x**3 + a2*x**2 + a1*x + a0 = 0) and solves the cubic in closed form
        !<          using q, r, delta = q**3 + r**2:
        !<            delta <= 0 : trigonometric form, theta = arccos(r / sqrt(-q**3))
        !<            delta >  0 : Cardano form with real (sign-preserving) cube roots
        !<          It is only referenced from the commented-out non-iterative branch of
        !<          get_surface_fluxes.
        ! ----------------------------------------------------------------------------
        real, intent(out) :: zeta                       !< = z/L [n/d]
        real, intent(in) :: Rib                         !< bulk Richardson number
        real, intent(in) :: z,z0m,z0t                   !< height and roughness lengths [m]

        real A,a0,a1,a2,r,q,s1,s2,theta,delta
        real ksi_m,ksi_h,ksi_m_0,ksi_m_inf,ksi_h_0,ksi_h_inf
        real f_m_inf,f_h_inf
        real psi_m_zeta,psi_m_zeta0,psi_h_zeta,psi_h_zeta0
        real cub_arg                                    !< argument of a real cube root

        ! ----------------------------------------------------------------------------

        A = ( 1 / Pr_t_0_inv ) * ( (1 - z0m/z)**2) * log(z/z0t) / ( (1 - z0t/z) * ((log(z/z0m))**2) )

        call get_psi_convection(psi_m_zeta,psi_h_zeta,Rib/A, Rib/A)
        call get_psi_convection(psi_m_zeta0,psi_h_zeta0, (z0m/z) * (Rib/A),(z0t/z) * (Rib/A))

        f_m_inf = 1 - (psi_m_zeta - psi_m_zeta0) / log(z/z0m)
        f_h_inf = 1 - (psi_h_zeta - psi_h_zeta0) / log(z/z0t)

        ksi_m_0 = ((z0m / z) - 1.0) / log(z0m/z)
        ksi_h_0 = ((z0t / z) - 1.0) / log(z0t/z)

        ksi_m_inf = (A / (beta_m * Rib)) * (1.0 - 1.0 / (f_m_inf**4))
        ksi_h_inf = (A / (beta_h * Rib)) * (1.0 - ((1.0 / Pr_t_0_inv)**2) / (f_h_inf**2))

        ksi_m = cc1 * ksi_m_inf + cc2 * ksi_m_0
        ksi_h = cc3 * ksi_h_inf + cc4 * ksi_h_0

        a0 = (1 / (beta_m * ksi_m)) * ((Rib / A)**2)
        a1 = -1 * (beta_h * ksi_h / (beta_m * ksi_m)) * ((Rib / A)**2)
        a2 = -1 / (beta_m * ksi_m)

        r = (9.0 * (a1 * a2 - 3 * a0) - 2.0 * (a2**3)) / 54.0
        q = (3.0 * a1 - a2*a2) / 9.0
        delta = q**3 + r**2

        if (delta <= 0.0) then
            ! three real roots (delta <= 0 implies q <= 0). The angle is the inverse
            ! cosine; clamp against round-off just outside [-1, 1].
            if (q < 0.0) then
                theta = acos(max(-1.0, min(1.0, r / sqrt(-1 * (q**3)))))
            else
                theta = 0.0     ! q = r = 0: triple root, the cosine term is multiplied by 0
            end if
            zeta = 2.0 * sqrt(-1.0 * q) * cos((theta + 2.0 * pi)/3.0) + 1 /(3.0 * beta_m * ksi_m)
        else
            ! one real root (Cardano). The cube roots must be real and keep the sign of
            ! their argument; the principal complex root of a negative number has a
            ! different real part.
            cub_arg = r + sqrt(delta)
            s1 = sign(abs(cub_arg)**(1./3.), cub_arg)
            cub_arg = r - sqrt(delta)
            s2 = sign(abs(cub_arg)**(1./3.), cub_arg)
            ! NOTE(review): this branch returns the negated root while the delta <= 0
            ! branch does not -- confirm the intended sign convention.
            zeta = -1*(s1 + s2 + 1.0 /(3.0 * beta_m * ksi_m))
        endif

    end subroutine get_zeta_conv


end module sfx_sheba_coare