#include <iostream>
#include <cmath>

#include "sfx-sheba.h"
#ifdef INCLUDE_CUDA
    #include "sfx-memory-processing.cuh"
#endif

#include "sfx-memory-processing.h"
#include "sfx-surface.cuh"
#include "sfx-model-compute-subfunc.cuh"

// Constructor: hands the data containers and the grid size on to ModelBase,
// then keeps by-value copies of the four parameter sets supplied by the caller.
//
//   sfx_in, meteo_in    - data containers, forwarded untouched to ModelBase
//   model_param_in      - stored in `model`
//   surface_param_in    - stored in `surface`
//   numerics_in         - stored in `numerics`
//   phys_constants_in   - stored in `phys`
//   grid_size_in        - forwarded untouched to ModelBase
template<typename T, MemType memIn, MemType memOut, MemType RunMem >
FluxShebaBase<T, memIn, memOut, RunMem>::FluxShebaBase(sfxDataVecTypeC* sfx_in,
                meteoDataVecTypeC* meteo_in,
                const sfx_sheba_param_C model_param_in,
                const sfx_surface_param surface_param_in,
                const sfx_sheba_numericsType_C numerics_in,
                const sfx_phys_constants phys_constants_in,
                const int grid_size_in)
    : ModelBase<T, memIn, memOut, RunMem>(sfx_in, meteo_in, grid_size_in)
{
    // The four copies are independent of each other; order is irrelevant.
    this->model    = model_param_in;
    this->numerics = numerics_in;
    this->phys     = phys_constants_in;
    this->surface  = surface_param_in;
}

// Destructor: performs no work of its own in this class.
template<typename T, MemType memIn, MemType memOut, MemType RunMem >
FluxShebaBase<T, memIn, memOut, RunMem>::~FluxShebaBase() = default;

// CPU implementation of the flux computation.
//
// For every index in [0, grid_size) the routine
//   1. reads U, Tsemi, dT, dQ, h and z0_m from `meteo`;
//   2. treats a negative input z0_m as "ocean" (roughness and u_dyn0 are then
//      obtained from get_charnock_roughness) and anything else as "land"
//      (u_dyn0 = U * kappa / log(h / z0_m));
//   3. derives Re, the thermal roughness z0_t and B, the bulk Richardson number,
//      the dynamic scales and the phi functions through the helper routines;
//   4. stores zeta, Rib, Re, B, z0_m, z0_t, Cm, Ct, Km and Pr_t_inv in `sfx`
//      (Rib_conv_lim is always written as zero).
//
// When the output memory type is not CPU, every result array is afterwards
// copied from `sfx` to `res_sfx` with memproc::memcopy.
template<typename T, MemType memIn, MemType memOut >
void FluxSheba<T, memIn, memOut, MemType::CPU>::compute_flux()
{
    // Iteration count handed to get_dynamic_scales.
    const int maxiters_dynamic_scales = 10;

    T h, U, dT, Tsemi, dQ, z0_m;
    T z0_t, B, u_dyn0, Re,
    zeta, Rib, Udyn, Tdyn, Qdyn, phi_m, phi_h,
    Km, Pr_t_inv, Cm, Ct;
    int surface_type;

    for (int step = 0; step < grid_size; step++)
    {
        // --- load the inputs of the current cell
        U = meteo.U[step];
        Tsemi = meteo.Tsemi[step];
        dT = meteo.dT[step];
        dQ = meteo.dQ[step];
        h = meteo.h[step];
        z0_m = meteo.z0_m[step];

        // --- a negative input roughness selects the ocean branch
        surface_type = z0_m < 0.0 ? surface.surface_ocean : surface.surface_land;

        // --- define dynamic roughness & u_dyn0
        // surface_type holds one of exactly two values, so a single if/else
        // covers both cases and u_dyn0 is assigned on every path.
        if (surface_type == surface.surface_ocean)
        {
            // z0_m and u_dyn0 are results of this call
            get_charnock_roughness(z0_m, u_dyn0, U, h, surface, numerics.maxiters_charnock);
        }
        else
        {
            // std::log picks the overload matching T, so the logarithm is not
            // forced through single precision.
            u_dyn0 = U * model.kappa / std::log(h / z0_m);
        }

        // --- define roughness Reynolds number & thermal roughness
        Re = u_dyn0 * z0_m / phys.nu_air;
        get_thermal_roughness(z0_t, B, z0_m, Re, surface, surface_type);

        // --- define Ri-bulk
        // NOTE(review): U == 0 makes this quotient non-finite; Cm below is
        // guarded against U == 0 but Rib is not - confirm the intended value.
        Rib = (phys.g / Tsemi) * h * (dT + 0.61e0 * Tsemi * dQ) / (U*U);

        // --- get the fluxes
        // ----------------------------------------------------------------------------
        get_dynamic_scales(Udyn, Tdyn, Qdyn, zeta,
            U, Tsemi, dT, dQ, h, z0_m, z0_t, (phys.g / Tsemi), model, maxiters_dynamic_scales);
        // ----------------------------------------------------------------------------

        get_phi(phi_m, phi_h, zeta, model);
        // ----------------------------------------------------------------------------

        // --- define transfer coeff. (momentum) & (heat)
        // both stay zero when their denominator is zero
        Cm = 0.0;
        if (U > 0.0)
            Cm = Udyn / U;
        Ct = 0.0;
        if (fabs(dT) > 0.0)
            Ct = Tdyn / dT;

        // --- define eddy viscosity & inverse Prandtl number
        Km = model.kappa * Cm * U * h / phi_m;
        Pr_t_inv = phi_m / phi_h;

        // --- store the results of the current cell
        sfx.zeta[step]         = zeta;
        sfx.Rib[step]          = Rib;
        sfx.Re[step]           = Re;
        sfx.B[step]            = B;
        sfx.z0_m[step]         = z0_m;
        sfx.z0_t[step]         = z0_t;
        sfx.Rib_conv_lim[step] = T(0.0);
        sfx.Cm[step]           = Cm;
        sfx.Ct[step]           = Ct;
        sfx.Km[step]           = Km;
        sfx.Pr_t_inv[step]     = Pr_t_inv;
    }

    // --- hand the results over when the output does not live in CPU memory
    if(MemType::CPU != memOut)
    {
        // size in bytes of one result array
        const size_t new_size = grid_size * sizeof(T);
        memproc::memcopy<memOut, MemType::CPU>((void*&)res_sfx->zeta, (void*&)sfx.zeta, new_size);
        memproc::memcopy<memOut, MemType::CPU>((void*&)res_sfx->Rib, (void*&)sfx.Rib, new_size);
        memproc::memcopy<memOut, MemType::CPU>((void*&)res_sfx->Re, (void*&)sfx.Re, new_size);
        memproc::memcopy<memOut, MemType::CPU>((void*&)res_sfx->B, (void*&)sfx.B, new_size);
        memproc::memcopy<memOut, MemType::CPU>((void*&)res_sfx->z0_m, (void*&)sfx.z0_m, new_size);
        memproc::memcopy<memOut, MemType::CPU>((void*&)res_sfx->z0_t, (void*&)sfx.z0_t, new_size);
        memproc::memcopy<memOut, MemType::CPU>((void*&)res_sfx->Rib_conv_lim, (void*&)sfx.Rib_conv_lim, new_size);
        memproc::memcopy<memOut, MemType::CPU>((void*&)res_sfx->Cm, (void*&)sfx.Cm, new_size);
        memproc::memcopy<memOut, MemType::CPU>((void*&)res_sfx->Ct, (void*&)sfx.Ct, new_size);
        memproc::memcopy<memOut, MemType::CPU>((void*&)res_sfx->Km, (void*&)sfx.Km, new_size);
        memproc::memcopy<memOut, MemType::CPU>((void*&)res_sfx->Pr_t_inv, (void*&)sfx.Pr_t_inv, new_size);
    }
}

// Explicit instantiations: the template members defined in this file are
// generated here for the listed argument combinations (float only).
// The three MemType arguments are, in order, memIn, memOut and RunMem.

// Pure CPU configuration - always built.
template class FluxSheba<float, MemType::CPU, MemType::CPU, MemType::CPU>;
template class FluxShebaBase<float, MemType::CPU, MemType::CPU, MemType::CPU>;

// Remaining seven CPU/GPU combinations - built only when INCLUDE_CUDA is defined.
#ifdef INCLUDE_CUDA
    template class FluxShebaBase<float, MemType::GPU, MemType::GPU, MemType::GPU>;
    template class FluxShebaBase<float, MemType::GPU, MemType::GPU, MemType::CPU>;
    template class FluxShebaBase<float, MemType::GPU, MemType::CPU, MemType::GPU>;
    template class FluxShebaBase<float, MemType::CPU, MemType::GPU, MemType::GPU>;
    template class FluxShebaBase<float, MemType::CPU, MemType::CPU, MemType::GPU>;
    template class FluxShebaBase<float, MemType::CPU, MemType::GPU, MemType::CPU>;
    template class FluxShebaBase<float, MemType::GPU, MemType::CPU, MemType::CPU>;

    // NOTE(review): only the RunMem == CPU compute_flux is defined in this file;
    // the RunMem == GPU variants are presumably defined elsewhere - confirm.
    template class FluxSheba<float, MemType::GPU, MemType::GPU, MemType::GPU>;
    template class FluxSheba<float, MemType::GPU, MemType::GPU, MemType::CPU>;
    template class FluxSheba<float, MemType::GPU, MemType::CPU, MemType::GPU>;
    template class FluxSheba<float, MemType::CPU, MemType::GPU, MemType::GPU>;
    template class FluxSheba<float, MemType::CPU, MemType::CPU, MemType::GPU>;
    template class FluxSheba<float, MemType::CPU, MemType::GPU, MemType::CPU>;
    template class FluxSheba<float, MemType::GPU, MemType::CPU, MemType::CPU>;
#endif