#pragma once 
#include "../includeC/sfx_data.h"
#include "../includeCU/sfx_math.cuh"

// Computes the limiting quantities of the convective branch.
// NOTE(review): names suggest Monin-Obukhov stability parameter (zeta) and
// bulk Richardson number (Rib) -- confirm against the scheme documentation.
//
// Outputs:
//   zeta_lim - root (minus-sqrt branch) of
//                  (1 - alpha_h*z)^2 = c * (1 - alpha_m*z),
//              with c = (Pr_t_inf_inv / Pr_t_0_inv)^4; equivalently
//              f_h_lim^2 = c * f_m_lim^4.
//   f_m_lim  - (1 - alpha_m*zeta_lim)^(1/4)
//   f_h_lim  - (1 - alpha_h*zeta_lim)^(1/2)
//   Rib_lim  - zeta_lim * psi_h / psi_m^2, where psi_m / psi_h are the
//              atan/log integral expressions between the lower-level
//              arguments zeta_lim/h0_m (resp. zeta_lim/h0_t) and zeta_lim.
// Inputs:
//   h0_m, h0_t - divisors giving the lower-level arguments
//                (presumably height over roughness length -- confirm)
//   B          - when |B| < 1e-10 the heat argument reuses zeta_lim/h0_m
//   param      - reads alpha_m, alpha_h, alpha_h_fix, Pr_t_inf_inv, Pr_t_0_inv
template<typename T>
FUCNTION_DECLARATION_SPECIFIER void get_convection_lim(T &zeta_lim, T &Rib_lim, T &f_m_lim, T &f_h_lim,
    const T h0_m, const T h0_t, const T B,
    const sfx_esm_param& param)
{
    // Ratio of inverse turbulent Prandtl numbers, raised to the 4th power.
    const T c = pow(param.Pr_t_inf_inv / param.Pr_t_0_inv, 4);

    // Smaller root of a_h^2 z^2 + (c a_m - 2 a_h) z + (1 - c) = 0.
    zeta_lim = (2.0 * param.alpha_h - c * param.alpha_m
                - sqrt((c * param.alpha_m) * (c * param.alpha_m)
                       + 4.0 * c * param.alpha_h * (param.alpha_h - param.alpha_m)))
               / (2.0 * param.alpha_h * param.alpha_h);

    f_m_lim = pow(1.0 - param.alpha_m * zeta_lim, 0.25);
    f_h_lim = sqrt(1.0 - param.alpha_h * zeta_lim);

    // Lower-level arguments; the heat one falls back to the momentum one
    // when |B| is negligible.
    const T zeta0_m = zeta_lim / h0_m;
    const T zeta0_h = (fabs(B) < 1.0e-10) ? zeta0_m : zeta_lim / h0_t;

    const T x0_m = pow(1.0 - param.alpha_m * zeta0_m, 0.25);
    // NOTE(review): f_h_lim above uses alpha_h, while this lower-level value
    // uses alpha_h_fix; kept exactly as before -- confirm this is intended.
    const T x0_h = sqrt(1.0 - param.alpha_h_fix * zeta0_h);

    const T psi_m = 2.0 * (atan(f_m_lim) - atan(x0_m))
                  + log((f_m_lim - 1.0) * (x0_m + 1.0) / ((f_m_lim + 1.0) * (x0_m - 1.0)));
    const T psi_h = log((f_h_lim - 1.0) * (x0_h + 1.0) / ((f_h_lim + 1.0) * (x0_h - 1.0)))
                  / param.Pr_t_0_inv;

    Rib_lim = zeta_lim * psi_h / (psi_m * psi_m);
}

// Stable-branch evaluation in closed form (no iteration).
//
// With L = log(h0_m) and k = beta_m * Rib:
//   zeta  = L * (sqrt(q^2 + 4k(1-k)) - q) / (2 * beta_m * (1-k)),
//           q = (B/L + 1) / Pr_t_0_inv - 2k
//   psi_m = L + beta_m*zeta
//   psi_h = (L + B) / Pr_t_0_inv + beta_m*zeta
// Substituting these linear forms into zeta*psi_h = Rib*psi_m^2 gives the
// quadratic (1-k) y^2 + q y - k = 0 in y = beta_m*zeta/L; zeta above is its
// "+sqrt" root.
// h0_t is part of the shared branch signature but is not read here.
template<typename T>
FUCNTION_DECLARATION_SPECIFIER void get_psi_stable(T &psi_m, T &psi_h, T &zeta,
    const T Rib, const T h0_m, const T h0_t, const T B,
    const sfx_esm_param& param)
{
    const T log_h0m    = log(h0_m);
    const T b_over_log = B / log_h0m;

    const T k = param.beta_m * Rib;
    const T q = (b_over_log + 1.0) / param.Pr_t_0_inv - 2.0 * k;

    zeta = log_h0m * (sqrt(q*q + 4.0 * k * (1.0 - k)) - q)
         / (2.0 * param.beta_m * (1.0 - k));

    // Linear stable correction shared by momentum and heat.
    const T shift = param.beta_m * zeta;

    psi_m = log_h0m + shift;
    psi_h = (log_h0m + B) / param.Pr_t_0_inv + shift;
}

// Convective-branch evaluation by fixed-point iteration on zeta.
//
// Starting from zeta = zeta_conv_lim, psi is evaluated
// numerics.maxiters_convection + 1 times. Before every evaluation except the
// first, zeta is refreshed as Rib * psi_m^2 / psi_h, so the returned
// psi_m / psi_h always correspond to the returned zeta. If
// maxiters_convection < 0 the loop never runs: zeta = zeta_conv_lim and
// psi_m / psi_h are not written.
//
// Each evaluation:
//   x0_m = (1 - alpha_m*zeta0_m)^(1/4),  x0_h = (1 - alpha_h_fix*zeta0_h)^(1/2)
//     with zeta0_m = zeta/h0_m, zeta0_h = zeta/h0_t (or zeta0_m if |B| < 1e-10)
//   corr  = 3 * (1 - (zeta_conv_lim / zeta)^(1/3))
//   psi_m = corr/f_m_conv_lim + lim_m + low_m(x0_m)
//   psi_h = (corr/f_h_conv_lim + lim_h + low_h(x0_h)) / Pr_t_0_inv
// NOTE(review): presumably zeta_conv_lim, f_m_conv_lim, f_h_conv_lim come
// from get_convection_lim -- confirm with the caller.
template<typename T>
FUCNTION_DECLARATION_SPECIFIER void get_psi_convection(T &psi_m, T &psi_h, T &zeta,
    const T Rib, const T h0_m, const T h0_t, const T B,
    const T zeta_conv_lim, const T f_m_conv_lim, const T f_h_conv_lim,
    const sfx_esm_param& param,
    const sfx_esm_numericsTypeC& numerics)
{
    // Iteration-invariant terms evaluated at the convection limit.
    const T lim_m = 2.0 * atan(f_m_conv_lim) + log((f_m_conv_lim - 1.0) / (f_m_conv_lim + 1.0));
    const T lim_h = log((f_h_conv_lim - 1.0) / (f_h_conv_lim + 1.0));
    const bool tiny_B = fabs(B) < 1.0e-10;

    zeta = zeta_conv_lim;

    for (int iter = 0; iter <= numerics.maxiters_convection; ++iter)
    {
        // Fixed-point refresh from the previous evaluation.
        if (iter > 0)
            zeta = Rib * psi_m * psi_m / psi_h;

        const T zeta0_m = zeta / h0_m;
        const T zeta0_h = tiny_B ? zeta0_m : zeta / h0_t;

        const T x0_m = pow(1.0 - param.alpha_m * zeta0_m, 0.25);
        const T x0_h = sqrt(1.0 - param.alpha_h_fix * zeta0_h);

        const T low_m = -2.0*atan(x0_m) + log((x0_m + 1.0)/(x0_m - 1.0));
        const T low_h = log((x0_h + 1.0)/(x0_h - 1.0));

        // Cube-root correction relative to the convection limit.
        const T ratio_cbrt = pow(zeta_conv_lim / zeta, 1.0 / 3.0);
        const T corr = 3.0 * (1.0 - ratio_cbrt);

        psi_m = corr / f_m_conv_lim + lim_m + low_m;
        psi_h = (corr / f_h_conv_lim + lim_h + low_h) / param.Pr_t_0_inv;
    }
}

// Neutral-branch evaluation: zeta = 0, psi_m = log(h0_m), and
// psi_h = log(h0_t) / Pr_t_0_inv, or psi_m / Pr_t_0_inv when |B| < 1e-10.
template<typename T>
FUCNTION_DECLARATION_SPECIFIER void get_psi_neutral(T &psi_m, T &psi_h, T &zeta,
    const T h0_m, const T h0_t, const T B,
    const sfx_esm_param& param)
{
    zeta  = 0.0;
    psi_m = log(h0_m);

    // A negligible |B| makes the heat term reuse the momentum logarithm.
    psi_h = (fabs(B) < 1.0e-10)
          ? psi_m / param.Pr_t_0_inv
          : log(h0_t) / param.Pr_t_0_inv;
}

// Semi-convective-branch evaluation by fixed-point iteration on zeta.
//
// Starting guess: psi_m = log(h0_m), psi_h = log(h0_t) (or log(h0_m) when
// |B| < 1e-10), zeta = Rib * Pr_t_0_inv * psi_m^2 / psi_h.
// psi is then evaluated numerics.maxiters_convection + 1 times. Before every
// evaluation except the first, zeta is refreshed as Rib * psi_m^2 / psi_h, so
// the returned psi_m / psi_h correspond to the returned zeta. If
// maxiters_convection < 0 the loop never runs and psi_m / psi_h keep the
// starting guess (psi_h not divided by Pr_t_0_inv).
//
// Each evaluation integrates between the lower-level arguments
// zeta/h0_m, zeta/h0_t and zeta itself:
//   x_m  = (1 - alpha_m*zeta)^(1/4),     x_h  = (1 - alpha_h_fix*zeta)^(1/2)
//   x0_m, x0_h: same at zeta0_m / zeta0_h, clamped via sfx_math::max to at
//               least 1.000001 (keeps the (x0 - 1) denominators away from 0)
template<typename T>
FUCNTION_DECLARATION_SPECIFIER void get_psi_semi_convection(T &psi_m, T &psi_h, T &zeta,
    const T Rib, const T h0_m, const T h0_t, const T B,
    const sfx_esm_param& param,
    const sfx_esm_numericsTypeC& numerics)
{
    const bool tiny_B = fabs(B) < 1.0e-10;

    // Neutral-like starting guess.
    psi_m = log(h0_m);
    psi_h = tiny_B ? psi_m : log(h0_t);

    zeta = Rib * param.Pr_t_0_inv * psi_m * psi_m / psi_h;

    for (int iter = 0; iter <= numerics.maxiters_convection; ++iter)
    {
        // Fixed-point refresh from the previous evaluation.
        if (iter > 0)
            zeta = Rib * psi_m * psi_m / psi_h;

        const T zeta0_m = zeta / h0_m;
        const T zeta0_h = tiny_B ? zeta0_m : zeta / h0_t;

        const T x_m = pow(1.0 - param.alpha_m * zeta, 0.25e0);
        const T x_h = sqrt(1.0 - param.alpha_h_fix * zeta);

        T x0_m = pow(1.0 - param.alpha_m * zeta0_m, 0.25e0);
        T x0_h = sqrt(1.0 - param.alpha_h_fix * zeta0_h);

        x0_m = sfx_math::max(x0_m, T(1.000001e0));
        x0_h = sfx_math::max(x0_h, T(1.000001e0));

        psi_m = log((x_m - 1.0e0)*(x0_m + 1.0e0)/((x_m + 1.0e0)*(x0_m - 1.0e0))) + 2.0e0*(atan(x_m) - atan(x0_m));
        psi_h = log((x_h - 1.0e0)*(x0_h + 1.0e0)/((x_h + 1.0e0)*(x0_h - 1.0e0))) / param.Pr_t_0_inv;
    }
}