#include <cmath>
#include <iostream>
#include "../includeCXX/sfx_compute_sheba.h"

// Computes per-point surface-layer diagnostics for `grid_size` independent
// points and stores them in the output arrays.
//
// For every index `step` in [0, grid_size) the routine:
//   1. reads U_, dT_, Tsemi_, dQ_, h_ and in_z0_m_ at `step`;
//   2. selects the surface type from the sign of the input roughness
//      (negative -> surface_ocean, otherwise surface_land);
//   3. obtains the momentum roughness / neutral dynamic velocity
//      (get_charnock_roughness for ocean, logarithmic law for land);
//   4. obtains the thermal roughness and B via get_thermal_roughness;
//   5. evaluates the bulk Richardson number;
//   6. calls get_dynamic_scales and get_phi, then derives Cm, Ct, Km and
//      Pr_t_inv from their results;
//   7. writes all results to the output arrays at `step`.
//
// Outputs (arrays, at least grid_size elements each, written at every index):
//   zeta_, Rib_, Re_, B_, z0_m_, z0_t_, Rib_conv_lim_, Cm_, Ct_, Km_, Pr_t_inv_
// Inputs (arrays, at least grid_size elements each, read only):
//   U_, dT_, Tsemi_, dQ_, h_, in_z0_m_
// Scalars actually used here: kappa, nu_air, g, maxiters_charnock, grid_size.
// The remaining scalar parameters are not referenced in this function body;
// they are kept so that the signature stays unchanged for callers.
//
// NOTE(review): the helpers get_charnock_roughness, get_thermal_roughness,
// get_dynamic_scales and get_phi, and the constants surface_ocean /
// surface_land, are declared outside this block; their exact contracts are
// not visible here. The helpers are presumed to fill their leading arguments
// by reference - confirm against their declarations.
// NOTE(review): no guard exists for U == 0 (division in Rib) - inherited
// behaviour, left unchanged.
template<typename T>
void compute_flux_sheba_cpu(T *zeta_, T *Rib_, T *Re_, T *B_, T *z0_m_, T *z0_t_, T *Rib_conv_lim_, T *Cm_, T *Ct_, T *Km_, T *Pr_t_inv_,
    const T *U_, const T *dT_, const T *Tsemi_, const T *dQ_, const T *h_, const T *in_z0_m_,
    const T kappa, const T Pr_t_0_inv, const T Pr_t_inf_inv, 
    const T alpha_m, const T alpha_h, const T alpha_h_fix, 
    const T beta_m, const T beta_h, const T Rib_max, const T Re_rough_min, 
    const T B1_rough, const T B2_rough,
    const T B_max_land, const T B_max_ocean, const T B_max_lake,
    const T gamma_c, const T Re_visc_min,
    const T Pr_m, const T nu_air, const T g, 
    const int maxiters_charnock, const int maxiters_convection, 
    const int grid_size)
{
    for (int step = 0; step < grid_size; step++)
    {
        // --- load the inputs for this point
        // (kept non-const: they are handed to helpers whose parameter
        //  kinds are not visible from this block)
        T U     = U_[step];
        T Tsemi = Tsemi_[step];
        T dT    = dT_[step];
        T dQ    = dQ_[step];
        T h     = h_[step];
        T z0_m  = in_z0_m_[step];

        // --- define surface type from the sign of the input roughness and
        //     get roughness [momentum] & dynamic velocity in neutral conditions
        int surface_type;
        T u_dyn0 = T(0);
        if (z0_m < 0.0)
        {
            surface_type = surface_ocean;
            // the iteration limit is the function parameter, not a member of
            // some `numerics` object (left-over from the non-C++ original)
            get_charnock_roughness(z0_m, u_dyn0, U, h, maxiters_charnock);
        }
        else
        {
            surface_type = surface_land;
            // --- relative height, then logarithmic law
            const T h0_m = h / z0_m;
            u_dyn0 = U * kappa / std::log(h0_m);
        }

        // --- define thermal roughness & B
        T Re = u_dyn0 * z0_m / nu_air;
        T z0_t = T(0), B = T(0);
        get_thermal_roughness(z0_t, B, z0_m, Re, surface_type);

        // --- define Ri-bulk
        T Rib = (g / Tsemi) * h * (dT + 0.61e0 * Tsemi * dQ) / (U*U);

        // --- get the fluxes
        // ----------------------------------------------------------------------------
        // the trailing literal 10 is passed through unchanged from the
        // original call - presumably an iteration count, TODO confirm
        T Udyn = T(0), Tdyn = T(0), Qdyn = T(0), zeta = T(0);
        get_dynamic_scales(Udyn, Tdyn, Qdyn, zeta, U, Tsemi, dT, dQ, h, z0_m, z0_t, (g / Tsemi), 10);
        // ----------------------------------------------------------------------------

        T phi_m = T(0), phi_h = T(0);
        get_phi(phi_m, phi_h, zeta);
        // ----------------------------------------------------------------------------

        // --- define transfer coeff. (momentum) & (heat);
        //     both stay zero when their denominator is zero
        T Cm = 0.0;
        if (U > 0.0)
            Cm = Udyn / U;
        T Ct = 0.0;
        if (std::fabs(dT) > 0.0) 
            Ct = Tdyn / dT;

        // --- define eddy viscosity & inverse Prandtl number
        T Km = kappa * Cm * U * h / phi_m;
        T Pr_t_inv = phi_m / phi_h;

        // --- store the results
        zeta_[step]         = zeta;
        Rib_[step]          = Rib;
        Re_[step]           = Re;
        B_[step]            = B;
        z0_m_[step]         = z0_m;
        z0_t_[step]         = z0_t;
        // No convective Rib limit is computed anywhere in this routine; the
        // original stored an undeclared identifier here. Zero is written so
        // the output array is always defined.
        // NOTE(review): confirm 0 is the intended value for this scheme.
        Rib_conv_lim_[step] = T(0);
        Cm_[step]           = Cm;
        Ct_[step]           = Ct;
        Km_[step]           = Km;
        Pr_t_inv_[step]     = Pr_t_inv;
    }
}

template void compute_flux_sheba_cpu(float *zeta_, float *Rib_, float *Re_, float *B_, float *z0_m_, float *z0_t_, float *Rib_conv_lim_, float *Cm_, float *Ct_, float *Km_, float *Pr_t_inv_,
    const float *U, const float *dt, const float *T_semi, const float *dq, const float *H, const float *in_z0_m,
    const float kappa, const float Pr_t_0_inv, const float Pr_t_inf_inv, 
    const float alpha_m, const float alpha_h, const float alpha_h_fix, 
    const float beta_m, const float beta_h, const float Rib_max, const float Re_rough_min, 
    const float B1_rough, const float B2_rough,
    const float B_max_land, const float B_max_ocean, const float B_max_lake,
    const float gamma_c, const float Re_visc_min,
    const float Pr_m, const float nu_air, const float g, 
    const int maxiters_charnock, const int maxiters_convection, 
    const int grid_size);
template void compute_flux_sheba_cpu(double *zeta_, double *Rib_, double *Re_, double *B_, double *z0_m_, double *z0_t_, double *Rib_conv_lim_, double *Cm_, double *Ct_, double *Km_, double *Pr_t_inv_,
    const double *U, const double *dt, const double *T_semi, const double *dq, const double *H, const double *in_z0_m, 
    const double kappa, const double Pr_t_0_inv, const double Pr_t_inf_inv, 
    const double alpha_m, const double alpha_h, const double alpha_h_fix, 
    const double beta_m, const double beta_h, const double Rib_max, const double Re_rough_min, 
    const double B1_rough, const double B2_rough,
    const double B_max_land, const double B_max_ocean, const double B_max_lake,
    const double gamma_c, const double Re_visc_min,
    const double Pr_m, const double nu_air, const double g, 
    const int maxiters_charnock, const int maxiters_convection, 
    const int grid_size);