#pragma once
#include "sfx-template-parameters.h"
#include "model-base.h"
#include "sfx-data.h"

// Shared state for every FluxEsm backend.
//
// Derives from ModelBase and adds the four ESM configuration structs as
// by-value members. The FluxEsm specializations in this header pick the
// backend through the RunMem template argument.
//
// Template parameters are forwarded unchanged to ModelBase:
//   T                      - element type (presumably the floating-point type; confirm)
//   memIn / memOut / RunMem - MemType tags for input, output and execution
template<typename T, MemType memIn, MemType memOut, MemType RunMem>
class FluxEsmBase : public ModelBase<T, memIn, memOut, RunMem>
{
    // Short name for the dependent base class.
    using Base = ModelBase<T, memIn, memOut, RunMem>;

public:
    // Members of a dependent base are not found by unqualified lookup inside
    // a template, so they are redeclared here (with public access).
    using Base::res_sfx;
    using Base::sfx;
    using Base::meteo;
    using Base::grid_size;
    using Base::ifAllocated;
    using Base::allocated_size;

    // ESM configuration, held by value. Declaration order is kept as-is
    // because it fixes member layout and initialisation order.
    sfx_surface_param surface;
    sfx_phys_constants phys;
    sfx_esm_param_C model;
    sfx_esm_numericsType_C numerics;

    // Constructor and destructor are only declared in this header; their
    // definitions live elsewhere.
    FluxEsmBase(sfxDataVecTypeC* sfx,
                meteoDataVecTypeC* meteo,
                const sfx_esm_param_C model,
                const sfx_surface_param surface,
                const sfx_esm_numericsType_C numerics,
                const sfx_phys_constants phys,
                const int grid_size);
    ~FluxEsmBase();
};

// Primary template: used for any RunMem value that has no specialization
// visible at the point of use (the CPU specialization is always declared; the
// GPU one only when INCLUDE_CUDA is defined).
//
// It declares no compute_flux() and no constructor. FluxEsmBase declares only
// its 7-argument constructor, so this fallback has no usable constructor.
template<typename T, MemType memIn, MemType memOut, MemType RunMem>
class FluxEsm : public FluxEsmBase<T, memIn, memOut, RunMem>
{
    // Empty: only the RunMem-specific specializations provide functionality.
};

// CPU specialization of FluxEsm (selected when RunMem == MemType::CPU).
// Forwards all construction arguments to FluxEsmBase and declares
// compute_flux(), whose definition is not in this header.
template<typename T, MemType memIn, MemType memOut>
class FluxEsm<T, memIn, memOut, MemType::CPU>
    : public FluxEsmBase<T, memIn, memOut, MemType::CPU>
{
private:
    // Short name for the dependent base class.
    using Base = FluxEsmBase<T, memIn, memOut, MemType::CPU>;

    // Dependent-base members are not found by unqualified lookup inside a
    // template, so they are redeclared here.
    // NOTE(review): these sit in the private section, so the members are not
    // accessible through a FluxEsm object even though FluxEsmBase declares
    // them public -- confirm this narrowing is intended.
    using Base::res_sfx;
    using Base::sfx;
    using Base::meteo;
    using Base::grid_size;
    using Base::ifAllocated;
    using Base::allocated_size;
    using Base::surface;
    using Base::phys;
    using Base::model;
    using Base::numerics;

public:
    // Passes every argument straight through to FluxEsmBase.
    FluxEsm(sfxDataVecTypeC* sfx,
            meteoDataVecTypeC* meteo,
            const sfx_esm_param_C model,
            const sfx_surface_param surface,
            const sfx_esm_numericsType_C numerics,
            const sfx_phys_constants phys,
            const int grid_size)
        : Base(sfx, meteo, model, surface, numerics, phys, grid_size)
    {
    }

    ~FluxEsm() = default;

    // Flux computation entry point for this backend; defined outside this header.
    void compute_flux();
};

#ifdef INCLUDE_CUDA
// GPU specialization of FluxEsm (selected when RunMem == MemType::GPU; only
// declared when INCLUDE_CUDA is defined). Forwards all construction arguments
// to FluxEsmBase and declares compute_flux(), whose definition is not in this
// header.
template<typename T, MemType memIn, MemType memOut>
class FluxEsm<T, memIn, memOut, MemType::GPU>
    : public FluxEsmBase<T, memIn, memOut, MemType::GPU>
{
private:
    // Short name for the dependent base class.
    using Base = FluxEsmBase<T, memIn, memOut, MemType::GPU>;

    // Dependent-base members are not found by unqualified lookup inside a
    // template, so they are redeclared here.
    // NOTE(review): these sit in the private section, so the members are not
    // accessible through a FluxEsm object even though FluxEsmBase declares
    // them public -- confirm this narrowing is intended.
    using Base::res_sfx;
    using Base::sfx;
    using Base::meteo;
    using Base::grid_size;
    using Base::ifAllocated;
    using Base::allocated_size;
    using Base::surface;
    using Base::phys;
    using Base::model;
    using Base::numerics;

public:
    // Passes every argument straight through to FluxEsmBase.
    FluxEsm(sfxDataVecTypeC* sfx,
            meteoDataVecTypeC* meteo,
            const sfx_esm_param_C model,
            const sfx_surface_param surface,
            const sfx_esm_numericsType_C numerics,
            const sfx_phys_constants phys,
            const int grid_size)
        : Base(sfx, meteo, model, surface, numerics, phys, grid_size)
    {
    }

    ~FluxEsm() = default;

    // Flux computation entry point for this backend; defined outside this header.
    void compute_flux();
};
#endif