#include <cstdio>

#include <cuda.h>
#include <cuda_runtime_api.h>

#include "sfx-esm.h"
#include "sfx-model-compute-subfunc.cuh"
#include "sfx-surface.cuh"
#include "sfx-memory-processing.cuh"

namespace sfx_kernel {

    // Forward declaration of the ESM surface-flux kernel. The templated
    // definition is given out-of-line in this file. Each thread processes the
    // grid column at its flat 1-D index; indices >= grid_size are ignored.
    template<typename T>
    __global__ void compute_flux(
        sfxDataVecTypeC              sfx,
        meteoDataVecTypeC            meteo,
        const sfx_esm_param_C        model,
        const sfx_surface_param      surface,
        const sfx_esm_numericsType_C numerics,
        const sfx_phys_constants     phys,
        const int                    grid_size);

}  // namespace sfx_kernel

// ESM surface-flux kernel: one thread per grid column.
//
// Launch layout: 1-D grid of 1-D blocks; the column index is
// blockIdx.x * blockDim.x + threadIdx.x. Threads whose index is >= grid_size
// exit immediately, so the grid may be rounded up to whole blocks. No shared
// memory is used and threads never communicate.
//
// Reads  : meteo.{U, Tsemi, dT, dQ, h, z0_m}[index]
// Writes : sfx.{zeta, Rib, Re, B, z0_m, z0_t, Rib_conv_lim, Cm, Ct, Km, Pr_t_inv}[index]
//
// Numeric constants are spelled T(...) instead of bare double literals so the
// float instantiation stays in single precision rather than silently
// promoting intermediate expressions to double.
template<typename T>
__global__ void sfx_kernel::compute_flux(sfxDataVecTypeC sfx,
    meteoDataVecTypeC meteo,
    const sfx_esm_param_C model, 
    const sfx_surface_param surface,
    const sfx_esm_numericsType_C numerics,
    const sfx_phys_constants phys,
    const int grid_size)
{
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index >= grid_size) return;   // tail threads of the last block

    T U     = meteo.U[index];
    T Tsemi = meteo.Tsemi[index];
    T dT    = meteo.dT[index];
    T dQ    = meteo.dQ[index];
    T h     = meteo.h[index];
    T z0_m  = meteo.z0_m[index];

    // A negative input roughness selects the ocean surface type; any other
    // value selects land.
    int surface_type = (z0_m < T(0)) ? surface.surface_ocean : surface.surface_land;

    T u_dyn0, h0_m;
    if (surface_type == surface.surface_ocean)
    {
        // Ocean: z0_m and u_dyn0 are obtained from get_charnock_roughness
        // (passed as outputs; z0_m is overwritten).
        get_charnock_roughness(z0_m, u_dyn0, U, h, surface, numerics.maxiters_charnock);
        h0_m = h / z0_m;
    }
    else
    {
        // Land: log-profile estimate of the dynamic velocity from the given z0_m.
        h0_m = h / z0_m;
        u_dyn0 = U * model.kappa / logf(h0_m);
    }

    T Re = u_dyn0 * z0_m / phys.nu_air;
    T z0_t, B;
    get_thermal_roughness(z0_t, B, z0_m, Re, surface, surface_type);

    T h0_t = h / z0_t;
    T Rib = (phys.g / Tsemi) * h * (dT + T(0.61) * Tsemi * dQ) / (U * U);

    T zeta_conv_lim, Rib_conv_lim, f_m_conv_lim, f_h_conv_lim;
    get_convection_lim(zeta_conv_lim, Rib_conv_lim, f_m_conv_lim, f_h_conv_lim, 
                       h0_m, h0_t, B, 
                       model);

    // Regime selection on the bulk Richardson number (checked in this order):
    //   Rib > 0                  -> stable (Rib capped at model.Rib_max)
    //   Rib < Rib_conv_lim       -> convection
    //   Rib > -0.001             -> treated as neutral
    //   otherwise                -> semi-convection
    T zeta, psi_m, psi_h, phi_m, phi_h;
    if (Rib > T(0)) 
    {
        Rib = sfx_math::min(Rib, model.Rib_max);
        get_psi_stable(psi_m, psi_h, zeta, Rib, h0_m, h0_t, B, model);

        const T fval = model.beta_m * zeta;
        phi_m = T(1) + fval;
        phi_h = T(1) / model.Pr_t_0_inv + fval;
    }
    else if (Rib < Rib_conv_lim) 
    {
        get_psi_convection(psi_m, psi_h, zeta, Rib, h0_m, h0_t, B, zeta_conv_lim, f_m_conv_lim, f_h_conv_lim, model, numerics.maxiters_convection);

        const T fval = powf(zeta_conv_lim / zeta, T(1.0 / 3.0));
        phi_m = fval / f_m_conv_lim;
        phi_h = fval / (model.Pr_t_0_inv * f_h_conv_lim);
    }
    else if (Rib > T(-0.001)) 
    {
        get_psi_neutral(psi_m, psi_h, zeta, h0_m, h0_t, B, model);

        phi_m = T(1);
        phi_h = T(1) / model.Pr_t_0_inv;
    }
    else
    {
        get_psi_semi_convection(psi_m, psi_h, zeta, Rib, h0_m, h0_t, B, model, numerics.maxiters_convection);

        phi_m = powf(T(1) - model.alpha_m * zeta, T(-0.25));
        phi_h = T(1) / (model.Pr_t_0_inv * sqrtf(T(1) - model.alpha_h_fix * zeta));
    }

    // Transfer coefficients, eddy viscosity and inverse turbulent Prandtl number.
    const T Cm       = model.kappa / psi_m;
    const T Ct       = model.kappa / psi_h;
    const T Km       = model.kappa * Cm * U * h / phi_m;
    const T Pr_t_inv = phi_m / phi_h;

    sfx.zeta[index]         = zeta;
    sfx.Rib[index]          = Rib;
    sfx.Re[index]           = Re;
    sfx.B[index]            = B;
    sfx.z0_m[index]         = z0_m;
    sfx.z0_t[index]         = z0_t;
    sfx.Rib_conv_lim[index] = Rib_conv_lim;
    sfx.Cm[index]           = Cm;
    sfx.Ct[index]           = Ct;
    sfx.Km[index]           = Km;
    sfx.Pr_t_inv[index]     = Pr_t_inv;
}

// GPU specialization: launches sfx_kernel::compute_flux<T> over all grid_size
// columns, then, if the requested output memory space (memOut) is not the
// GPU, copies each device-side result array in `sfx` into `res_sfx`.
template<typename T, MemType memIn, MemType memOut >
void FluxEsm<T, memIn, memOut, MemType::GPU>::compute_flux()
{
    // 256 threads per block: a multiple of the warp size that keeps headroom
    // for this register-heavy kernel (with 1024-thread blocks, the launch fails
    // on GPUs with a 64K-register-per-block limit once a thread needs > 64
    // registers).
    const int block_size = 256;

    // A zero-sized grid is an invalid launch configuration; skip the launch.
    if (grid_size > 0)
    {
        // Overflow-safe integer ceil-division. A float-based ceil is inexact
        // above 2^24 columns and could under-launch the tail.
        const int block_count = grid_size / block_size + (grid_size % block_size != 0 ? 1 : 0);
        const dim3 cuBlock(block_size, 1, 1);
        const dim3 cuGrid(block_count, 1, 1);

        sfx_kernel::compute_flux<T><<<cuGrid, cuBlock>>>(sfx, meteo, model, 
                                                        surface, numerics, phys, grid_size);

        // Kernel launches return no status; query it so a failed launch is
        // reported instead of silently leaving stale results in the outputs.
        const cudaError_t launch_err = cudaGetLastError();
        if (launch_err != cudaSuccess)
        {
            fprintf(stderr, "sfx: compute_flux kernel launch failed: %s\n",
                    cudaGetErrorString(launch_err));
        }
    }

    if(MemType::GPU != memOut)
    {
        // NOTE(review): assumes memproc::memcopy is ordered after the kernel
        // (e.g. a blocking cudaMemcpy on the default stream) -- confirm.
        const size_t new_size = grid_size * sizeof(T);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->zeta, (void*&)sfx.zeta, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Rib, (void*&)sfx.Rib, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Re, (void*&)sfx.Re, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->B, (void*&)sfx.B, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->z0_m, (void*&)sfx.z0_m, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->z0_t, (void*&)sfx.z0_t, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Rib_conv_lim, (void*&)sfx.Rib_conv_lim, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Cm, (void*&)sfx.Cm, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Ct, (void*&)sfx.Ct, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Km, (void*&)sfx.Km, new_size);
        memproc::memcopy<memOut, MemType::GPU>((void*&)res_sfx->Pr_t_inv, (void*&)sfx.Pr_t_inv, new_size);
    }
}

// Explicit instantiations emitted by this translation unit, float only.
// Each FluxEsmBase combination is paired with the matching FluxEsm one and
// grouped by the fourth template argument. The all-CPU combination is not
// instantiated here.

// Fourth argument == MemType::GPU
template class FluxEsmBase<float, MemType::GPU, MemType::GPU, MemType::GPU>;
template class FluxEsm    <float, MemType::GPU, MemType::GPU, MemType::GPU>;

template class FluxEsmBase<float, MemType::GPU, MemType::CPU, MemType::GPU>;
template class FluxEsm    <float, MemType::GPU, MemType::CPU, MemType::GPU>;

template class FluxEsmBase<float, MemType::CPU, MemType::GPU, MemType::GPU>;
template class FluxEsm    <float, MemType::CPU, MemType::GPU, MemType::GPU>;

template class FluxEsmBase<float, MemType::CPU, MemType::CPU, MemType::GPU>;
template class FluxEsm    <float, MemType::CPU, MemType::CPU, MemType::GPU>;

// Fourth argument == MemType::CPU
template class FluxEsmBase<float, MemType::GPU, MemType::GPU, MemType::CPU>;
template class FluxEsm    <float, MemType::GPU, MemType::GPU, MemType::CPU>;

template class FluxEsmBase<float, MemType::CPU, MemType::GPU, MemType::CPU>;
template class FluxEsm    <float, MemType::CPU, MemType::GPU, MemType::CPU>;

template class FluxEsmBase<float, MemType::GPU, MemType::CPU, MemType::CPU>;
template class FluxEsm    <float, MemType::GPU, MemType::CPU, MemType::CPU>;