#include <cuda.h>
#include <cuda_runtime_api.h>

#include "sfx-sheba.h"
#include "sfx-model-compute-subfunc.cuh"
#include "sfx-surface.cuh"
#include "sfx-memory-processing.cuh"

namespace sfx_kernel
{
    // Forward declaration of the flux kernel that takes the
    // sfx_sheba_param_C / sfx_sheba_numericsType_C parameter sets.
    // Defined later in this file.
    template<typename T>
    __global__ void compute_flux(
        sfxDataVecTypeC sfx,                        // output arrays
        meteoDataVecTypeC meteo,                    // input arrays
        const sfx_sheba_param_C model,              // model parameters
        const sfx_surface_param surface,            // surface parameters
        const sfx_sheba_numericsType_C numerics,    // numerical settings
        const sfx_phys_constants phys,              // physical constants
        const int grid_size);                       // number of cells to process

    // Forward declaration of the flux kernel that takes the
    // sfx_sheba_noit_param_C / sfx_sheba_noit_numericsType_C parameter sets.
    // Defined later in this file.
    template<typename T>
    __global__ void noit_compute_flux(
        sfxDataVecTypeC sfx,                            // output arrays
        meteoDataVecTypeC meteo,                        // input arrays
        const sfx_sheba_noit_param_C model,             // model parameters
        const sfx_surface_param surface,                // surface parameters
        const sfx_sheba_noit_numericsType_C numerics,   // numerical settings
        const sfx_phys_constants phys,                  // physical constants
        const int grid_size);                           // number of cells to process
}

// Per-cell surface flux kernel: each thread processes the single cell whose
// flat index is blockIdx.x * blockDim.x + threadIdx.x.
//
// Inputs  (read from `meteo` at that index): U, Tsemi, dT, dQ, h, z0_m.
// Outputs (written to `sfx` at that index):  zeta, Rib, Re, B, z0_m, z0_t,
//         Rib_conv_lim (always zero here), Cm, Ct, Km, Pr_t_inv.
//
// Launch layout: 1D grid of 1D blocks. Threads whose index is >= grid_size
// return without touching memory, so the grid may overshoot grid_size.
//
// All literal constants are written as T(...) so that no intermediate is
// silently promoted to double when T is float.
// NOTE(review): logf/fabsf are single-precision calls; they would truncate if
// this kernel were ever instantiated with T = double.
// NOTE(review): Rib divides by U*U with no guard for U == 0.
template<typename T>
__global__ void sfx_kernel::compute_flux(sfxDataVecTypeC sfx,
    meteoDataVecTypeC meteo,
    const sfx_sheba_param_C model, 
    const sfx_surface_param surface,
    const sfx_sheba_numericsType_C numerics,
    const sfx_phys_constants phys,
    const int grid_size)
{
    const int index = blockIdx.x * blockDim.x + threadIdx.x;

    // Tail guard: the last block may contain threads past the end of the data.
    if (index >= grid_size)
        return;

    T z0_t, B, u_dyn0, zeta, Udyn, Tdyn, Qdyn, phi_m, phi_h;
    int surface_type;

    T U     = meteo.U[index];
    T Tsemi = meteo.Tsemi[index];
    T dT    = meteo.dT[index];
    T dQ    = meteo.dQ[index];
    T h     = meteo.h[index];
    T z0_m  = meteo.z0_m[index];

    // A negative input roughness selects the ocean surface type.
    surface_type = z0_m < T(0) ? surface.surface_ocean : surface.surface_land;

    if (surface_type == surface.surface_ocean) 
    {
        // z0_m and u_dyn0 are passed as the first two arguments and are
        // expected to be filled in by the helper.
        get_charnock_roughness(z0_m, u_dyn0, U, h, surface, numerics.maxiters_charnock);
    }
    if (surface_type == surface.surface_land) 
    {
        u_dyn0 = U * model.kappa / logf(h / z0_m);
    }

    T Re = u_dyn0 * z0_m / phys.nu_air;
    get_thermal_roughness(z0_t, B, z0_m, Re, surface, surface_type);

    // --- define Ri-bulk
    T Rib = (phys.g / Tsemi) * h * (dT + T(0.61) * Tsemi * dQ) / (U * U);

    // --- get the fluxes
    // ----------------------------------------------------------------------------
    // The trailing literal 10 is passed through unchanged from the original call.
    get_dynamic_scales(Udyn, Tdyn, Qdyn, zeta, 
        U, Tsemi, dT, dQ, h, z0_m, z0_t, (phys.g / Tsemi), model, 10);
    // ----------------------------------------------------------------------------

    get_phi(phi_m, phi_h, zeta, model);
    // ----------------------------------------------------------------------------

    // --- define transfer coeff. (momentum) & (heat)
    // Both default to zero when the corresponding denominator is zero.
    T Cm = T(0);
    if (U > T(0))
        Cm = Udyn / U;
    T Ct = T(0);
    if (fabsf(dT) > T(0)) 
        Ct = Tdyn / dT;

    // --- define eddy viscosity & inverse Prandtl number
    T Km = model.kappa * Cm * U * h / phi_m;
    T Pr_t_inv = phi_m / phi_h;

    sfx.zeta[index]         = zeta;
    sfx.Rib[index]          = Rib;
    sfx.Re[index]           = Re;
    sfx.B[index]            = B;
    sfx.z0_m[index]         = z0_m;
    sfx.z0_t[index]         = z0_t;
    sfx.Rib_conv_lim[index] = T(0.0);
    sfx.Cm[index]           = Cm;
    sfx.Ct[index]           = Ct;
    sfx.Km[index]           = Km;
    sfx.Pr_t_inv[index]     = Pr_t_inv;
}

// Per-cell surface flux kernel for the "noit" parameter set: each thread
// processes the single cell whose flat index is
// blockIdx.x * blockDim.x + threadIdx.x.
//
// Inputs  (read from `meteo` at that index): U, Tsemi, dT, dQ, h, z0_m.
// Outputs (written to `sfx` at that index):  zeta, Rib, Re, B, z0_m, z0_t,
//         Rib_conv_lim, Cm, Ct, Km, Pr_t_inv.
//
// The bulk Richardson number Rib selects one of four branches:
//   Rib > 0               -> stable stratification
//   Rib < Rib_conv_lim    -> strong instability
//   Rib > -0.001          -> nearly neutral
//   otherwise             -> weak & semistrong instability
//
// Launch layout: 1D grid of 1D blocks. Threads whose index is >= grid_size
// return without touching memory, so the grid may overshoot grid_size.
//
// All literal constants are written as T(...) so that no intermediate is
// silently promoted to double when T is float.
// NOTE(review): logf/powf/sqrtf/fabsf are single-precision calls; they would
// truncate if this kernel were ever instantiated with T = double.
// NOTE(review): Rib divides by U*U with no guard for U == 0.
template<typename T>
__global__ void sfx_kernel::noit_compute_flux(sfxDataVecTypeC sfx,
    meteoDataVecTypeC meteo,
    const sfx_sheba_noit_param_C model, 
    const sfx_surface_param surface,
    const sfx_sheba_noit_numericsType_C numerics,
    const sfx_phys_constants phys,
    const int grid_size)
{
    const int index = blockIdx.x * blockDim.x + threadIdx.x;

    // Tail guard: the last block may contain threads past the end of the data.
    if (index >= grid_size)
        return;

    T z0_t, B, h0_m, h0_t, u_dyn0, Re, 
    zeta, Rib, zeta_conv_lim, Rib_conv_lim, 
    f_m_conv_lim, f_h_conv_lim,
    psi_m, psi_h, 
    psi0_m, psi0_h,
    Udyn, Tdyn, 
    phi_m, phi_h,
    Km, Pr_t_inv, Cm, Ct,
    fval;
    int surface_type;

    T U     = meteo.U[index];
    T Tsemi = meteo.Tsemi[index];
    T dT    = meteo.dT[index];
    T dQ    = meteo.dQ[index];
    T h     = meteo.h[index];
    T z0_m  = meteo.z0_m[index];

    // A negative input roughness selects the ocean surface type.
    surface_type = z0_m < T(0) ? surface.surface_ocean : surface.surface_land;

    if (surface_type == surface.surface_ocean) 
    {
        // z0_m and u_dyn0 are passed as the first two arguments and are
        // expected to be filled in by the helper; h0_m is derived afterwards.
        get_charnock_roughness(z0_m, u_dyn0, U, h, surface, numerics.maxiters_charnock);
        h0_m = h / z0_m;
    }
    if (surface_type == surface.surface_land) 
    {
        h0_m = h / z0_m;
        u_dyn0 = U * model.kappa / logf(h0_m);
    }

    Re = u_dyn0 * z0_m / phys.nu_air;
    get_thermal_roughness(z0_t, B, z0_m, Re, surface, surface_type);

    // --- define relative height [thermal]
    h0_t = h / z0_t;

    // --- define Ri-bulk
    Rib = (phys.g / Tsemi) * h * (dT + T(0.61) * Tsemi * dQ) / (U * U);

    get_convection_lim(zeta_conv_lim, Rib_conv_lim, f_m_conv_lim, f_h_conv_lim, h0_m, h0_t, B, model);

    // --- get the fluxes
    // ----------------------------------------------------------------------------
    if (Rib > T(0))
    {
        // --- stable stratification block

        // The clamp "Rib = min(Rib, Rib_max)" mentioned by the original
        // comment is not applied; Rib is written to the output as computed.
        get_zeta(zeta, Rib, h, z0_m, z0_t, model);
        get_psi_stable(psi_m, psi_h, zeta, zeta, model);
        get_psi_stable(psi0_m, psi0_h, zeta * z0_m / h, zeta * z0_t / h, model);

        phi_m = T(1) + (model.a_m * zeta * powf(T(1) + zeta, T(1) / T(3))) / (T(1) + model.b_m * zeta);
        phi_h = T(1) + (model.a_h * zeta + model.b_h * zeta * zeta) / (T(1) + model.c_h * zeta + zeta * zeta);

        // Udyn/Tdyn are only defined (and only consumed) on this branch.
        Udyn = model.kappa * U / (logf(h / z0_m) - (psi_m - psi0_m));
        Tdyn = model.kappa * dT * model.Pr_t_0_inv / (logf(h / z0_t) - (psi_h - psi0_h));
    }
    else if (Rib < Rib_conv_lim) 
    {    
        // --- strong instability block

        get_psi_convection(psi_m, psi_h, zeta, Rib,
            zeta_conv_lim, f_m_conv_lim, f_h_conv_lim, h0_m, h0_t, B, model, numerics.maxiters_convection);

        fval = powf(zeta_conv_lim / zeta, T(1) / T(3));
        phi_m = fval / f_m_conv_lim;
        phi_h = fval / (model.Pr_t_0_inv * f_h_conv_lim);
    }
    else if (Rib > T(-0.001))
    { 
        // --- nearly neutral [-0.001, 0] block

        get_psi_neutral(psi_m, psi_h, h0_m, h0_t, B, model);

        zeta = T(0);
        phi_m = T(1);
        phi_h = T(1) / model.Pr_t_0_inv;
    }
    else
    {
        // --- weak & semistrong instability block

        get_psi_semi_convection(psi_m, psi_h, zeta, Rib, h0_m, h0_t, B, model, numerics.maxiters_convection);

        phi_m = powf(T(1) - model.alpha_m * zeta, T(-0.25));
        phi_h = T(1) / (model.Pr_t_0_inv * sqrtf(T(1) - model.alpha_h_fix * zeta));
    } 
    // ----------------------------------------------------------------------------

    // --- define transfer coeff. (momentum) & (heat)
    if (Rib > T(0))
    {    
        // Stable case: derive from the dynamic scales, defaulting to zero
        // when the corresponding denominator is zero.
        Cm = T(0);
        if (U > T(0))
            Cm = Udyn / U;

        Ct = T(0);
        if (fabsf(dT) > T(0))
            Ct = Tdyn / dT;
    }
    else
    { 
        // All non-stable branches: derive directly from psi_m / psi_h.
        Cm = model.kappa / psi_m;
        Ct = model.kappa / psi_h;
    }

    // --- define eddy viscosity & inverse Prandtl number
    Km = model.kappa * Cm * U * h / phi_m;
    Pr_t_inv = phi_m / phi_h;

    sfx.zeta[index]         = zeta;
    sfx.Rib[index]          = Rib;
    sfx.Re[index]           = Re;
    sfx.B[index]            = B;
    sfx.z0_m[index]         = z0_m;
    sfx.z0_t[index]         = z0_t;
    sfx.Rib_conv_lim[index] = Rib_conv_lim;
    sfx.Cm[index]           = Cm;
    sfx.Ct[index]           = Ct;
    sfx.Km[index]           = Km;
    sfx.Pr_t_inv[index]     = Pr_t_inv;
}

// Host-side launcher for sfx_kernel::compute_flux.
//
// Launches the kernel over `grid_size` cells with 1024 threads per block,
// passing the member arrays `sfx` and `meteo` plus the given parameter sets.
// When memOut is not MemType::GPU, each of the eleven output arrays
// (grid_size * sizeof(T) bytes each) is then copied from `sfx` into `res_sfx`
// via memproc::memcopy.
//
// The block count uses integer ceil-division. The previous
// ceil(float(grid_size) / 1024.0) form loses precision for grid_size > 2^24
// (float has a 24-bit significand), which could under-count blocks and leave
// the last cells unprocessed.
//
// NOTE(review): launch errors are not queried here; callers that need them
// should check cudaGetLastError() after this call.
template<typename T, MemType memIn, MemType memOut >
void FluxSheba<T, memIn, memOut, MemType::GPU>::compute_flux(const sfx_sheba_param_C model, 
    const sfx_surface_param surface,
    const sfx_sheba_numericsType_C numerics,
    const sfx_phys_constants phys)
{
    const int threads_per_block = 1024;

    // A zero-block launch is an invalid configuration, so skip the launch
    // entirely when there is nothing to process.
    if (grid_size > 0)
    {
        // Division + remainder form avoids overflow of grid_size + (threads - 1).
        const int BlockCount = int(grid_size / threads_per_block)
                             + (grid_size % threads_per_block != 0 ? 1 : 0);

        dim3 cuBlock = dim3(threads_per_block, 1, 1);
        dim3 cuGrid = dim3(BlockCount, 1, 1);

        sfx_kernel::compute_flux<T><<<cuGrid, cuBlock>>>(sfx, meteo, model, 
                                                        surface, numerics, phys, grid_size);
    }

    if(MemType::GPU != memOut)
    {
        const size_t new_size = grid_size * sizeof(T);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->zeta, (void*&)sfx.zeta, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Rib, (void*&)sfx.Rib, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Re, (void*&)sfx.Re, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->B, (void*&)sfx.B, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->z0_m, (void*&)sfx.z0_m, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->z0_t, (void*&)sfx.z0_t, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Rib_conv_lim, (void*&)sfx.Rib_conv_lim, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Cm, (void*&)sfx.Cm, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Ct, (void*&)sfx.Ct, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Km, (void*&)sfx.Km, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Pr_t_inv, (void*&)sfx.Pr_t_inv, new_size);
    }
}

// Host-side launcher for sfx_kernel::noit_compute_flux.
//
// Launches the kernel over `grid_size` cells with 512 threads per block,
// passing the member arrays `sfx` and `meteo` plus the given parameter sets.
// When memOut is not MemType::GPU, each of the eleven output arrays
// (grid_size * sizeof(T) bytes each) is then copied from `sfx` into `res_sfx`
// via memproc::memcopy.
//
// The block count uses integer ceil-division. The previous
// ceil(float(grid_size) / 512.0) form loses precision for grid_size > 2^24
// (float has a 24-bit significand), which could under-count blocks and leave
// the last cells unprocessed.
//
// NOTE(review): launch errors are not queried here; callers that need them
// should check cudaGetLastError() after this call.
template<typename T, MemType memIn, MemType memOut >
void FluxSheba<T, memIn, memOut, MemType::GPU>::noit_compute_flux(const sfx_sheba_noit_param_C model, 
    const sfx_surface_param surface,
    const sfx_sheba_noit_numericsType_C numerics,
    const sfx_phys_constants phys)
{
    const int threads_per_block = 512;

    // A zero-block launch is an invalid configuration, so skip the launch
    // entirely when there is nothing to process.
    if (grid_size > 0)
    {
        // Division + remainder form avoids overflow of grid_size + (threads - 1).
        const int BlockCount = int(grid_size / threads_per_block)
                             + (grid_size % threads_per_block != 0 ? 1 : 0);

        dim3 cuBlock = dim3(threads_per_block, 1, 1);
        dim3 cuGrid = dim3(BlockCount, 1, 1);

        sfx_kernel::noit_compute_flux<T><<<cuGrid, cuBlock>>>(sfx, meteo, model, 
                                                        surface, numerics, phys, grid_size);
    }

    if(MemType::GPU != memOut)
    {
        const size_t new_size = grid_size * sizeof(T);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->zeta, (void*&)sfx.zeta, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Rib, (void*&)sfx.Rib, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Re, (void*&)sfx.Re, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->B, (void*&)sfx.B, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->z0_m, (void*&)sfx.z0_m, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->z0_t, (void*&)sfx.z0_t, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Rib_conv_lim, (void*&)sfx.Rib_conv_lim, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Cm, (void*&)sfx.Cm, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Ct, (void*&)sfx.Ct, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Km, (void*&)sfx.Km, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Pr_t_inv, (void*&)sfx.Pr_t_inv, new_size);
    }
}

// Explicit instantiations of FluxSheba for T = float over the MemType
// combinations listed below, so that the template members are emitted from
// this translation unit.
// NOTE(review): three of these have MemType::CPU as the last argument and so
// do not match the FluxSheba<T, memIn, memOut, MemType::GPU> specialization
// whose members are defined in this file; presumably their definitions come
// from the included headers - confirm.
template class FluxSheba<float, MemType::GPU, MemType::GPU, MemType::GPU>;
template class FluxSheba<float, MemType::GPU, MemType::GPU, MemType::CPU>;
template class FluxSheba<float, MemType::GPU, MemType::CPU, MemType::GPU>;
template class FluxSheba<float, MemType::CPU, MemType::GPU, MemType::GPU>;
template class FluxSheba<float, MemType::CPU, MemType::CPU, MemType::GPU>;
template class FluxSheba<float, MemType::CPU, MemType::GPU, MemType::CPU>;
template class FluxSheba<float, MemType::GPU, MemType::CPU, MemType::CPU>;