#include "../include/mesh.hpp"
#include "../include/matvec.hpp"
#include "MemoryProcessing.h"

#ifdef INCLUDE_CUDA
    #include "MemoryProcessing.cuh"
#endif
namespace icethermo
{
    // Default constructor: creates a mesh without any layers.
    // No thickness array is allocated, so `cells_thickness` stays empty.
    template <typename NumType, MemType memtype>
    Mesh<NumType, memtype>::Mesh()
    {
        // Record that zero bytes of thickness data are held; without this,
        // GetCellsNum()/GetNodesNum() would read an indeterminate value
        // (unless the class declaration already supplies a default initializer).
        cells_thickness_size_t = 0;

        // The host staging buffer is created lazily by PullCPUArray.
        SubBuffer = nullptr;
        SubBuffer_size = 0;
    }

    // Builds a mesh of `n_uniform_layers` cells of equal thickness.
    //
    // n_uniform_layers - number of cells, must be positive (otherwise THERMO_ERR is raised)
    // thickness        - total thickness; every cell gets thickness/n_uniform_layers
    //
    // The thickness array is always allocated through memproc::realloc<memtype>, so it is
    // released by the matching memproc::dealloc<memtype> call in the destructor
    // (previously the CPU path handed a `new[]` array to memproc::dealloc).
    template <typename NumType, MemType memtype>
    Mesh<NumType, memtype>::Mesh(int n_uniform_layers, NumType thickness)
    {
        if (n_uniform_layers <= 0)
        {
            THERMO_ERR("Number of layers in mesh should be greater than 0!");
        }

        // Fill the values in a temporary host array; it is freed automatically on scope exit.
        std::unique_ptr<NumType[]> host_thicknesses(new NumType[n_uniform_layers]);
        for (int i = 0; i < n_uniform_layers; ++i)
        {
            host_thicknesses[i] = thickness/n_uniform_layers;
        }

        // Allocate the mesh-owned array in `memtype` memory and copy the host values into it.
        NumType* thicknesses = nullptr;
        cells_thickness_size_t = 0;
        memproc::realloc<memtype>((void *&)(thicknesses), cells_thickness_size_t, n_uniform_layers * sizeof(NumType));
        memproc::memcopy<memtype, MemType::CPU>(thicknesses, host_thicknesses.get(), cells_thickness_size_t);
        cells_thickness = std::make_shared<NumType*>(thicknesses);

        // The host staging buffer is created lazily by PullCPUArray.
        SubBuffer      = nullptr;
        SubBuffer_size = 0;
    }

    // Convenience constructor: delegates to the uniform-layer constructor
    // with 10 layers and the given total thickness.
    template <typename NumType, MemType memtype>
    Mesh<NumType, memtype>::Mesh(NumType thickness)
        : Mesh(10, thickness)
    {
        // Start without a host staging buffer.
        this->SubBuffer_size = 0;
        this->SubBuffer = nullptr;
    }

    // Builds a mesh from a decomposition of the unit segment.
    //
    // unit_segment_decomposition      - relative cell sizes; their sum must equal 1 within 1e-5,
    //                                   otherwise THERMO_ERR is raised. The array is read with
    //                                   `memtype` routines, so it is expected to live in `memtype` memory.
    // unit_segment_decomposition_size - number of entries in the array
    // thickness                       - total thickness; each relative size is multiplied by it
    template <typename NumType, MemType memtype>
    Mesh<NumType, memtype>::Mesh(const NumType* unit_segment_decomposition, const int unit_segment_decomposition_size, NumType thickness)
    {   
        NumType sum_decomp = sum_vec<NumType, memtype>(unit_segment_decomposition, unit_segment_decomposition_size);
        if (std::abs(sum_decomp - 1.0) > 1e-5)
        {
            THERMO_ERR("Unit segment decomposition of length 1.0 should be given!");
        }
        
        // Start from a null pointer so the allocator never receives an indeterminate value.
        NumType* thicknesses = nullptr;
        cells_thickness_size_t = 0;

        // Allocate, copy the relative sizes, then scale them by the total thickness.
        memproc::realloc<memtype>((void *&)(thicknesses), cells_thickness_size_t, unit_segment_decomposition_size * sizeof(NumType));
        memproc::memcopy<memtype, memtype>(thicknesses, unit_segment_decomposition, cells_thickness_size_t);
        mul_vec<NumType, memtype>(thicknesses, thickness, unit_segment_decomposition_size);
        cells_thickness = std::make_shared<NumType*>(thicknesses);

        // The host staging buffer is created lazily by PullCPUArray.
        SubBuffer      = nullptr;
        SubBuffer_size = 0;
    }

    // Copy constructor: produces an independent deep copy of `other`.
    //
    // Every array (cell thicknesses, cells data, nodes data) is duplicated into a freshly
    // allocated buffer. Copying only the raw pointers would make both meshes release the
    // same memory in their destructors.
    template <typename NumType, MemType memtype>
    Mesh<NumType, memtype>::Mesh(const Mesh<NumType, memtype>& other)
    {   
        // Allocates a new `memtype` buffer and copies `n_bytes` bytes from `src` into it.
        // On return `n_bytes` holds the size actually allocated (0 when there was nothing to copy).
        auto clone_buffer = [](const NumType* src, size_t& n_bytes) -> std::shared_ptr<NumType*>
        {
            NumType* dst = nullptr;
            size_t allocated = 0;
            if (src != nullptr && n_bytes > 0)
            {
                memproc::realloc<memtype>((void *&)(dst), allocated, n_bytes);
                memproc::memcopy<memtype, memtype>(dst, src, allocated);
            }
            n_bytes = allocated;
            return std::make_shared<NumType*>(dst);
        };

        // cell thicknesses (absent when `other` was default-constructed)
        this->cells_thickness_size_t = 0;
        if (other.cells_thickness)
        {
            size_t n_bytes = other.cells_thickness_size_t;
            this->cells_thickness = clone_buffer(*(other.cells_thickness), n_bytes);
            this->cells_thickness_size_t = n_bytes;
        }

        // single values
        for (const auto& item: other.single_data)
        {
            const auto& value = item.second;

            (this->single_data)[item.first] = 
            std::make_pair
            (
                std::make_shared<NumType>(*(value.first)),
                value.second
            );
        }

        // cells data
        for (const auto& item: other.cells_data)
        {
            const auto& key = item.first;
            const auto& value = item.second;

            const auto size_it = other.cells_data_size_t.find(key);
            size_t n_bytes = (size_it != other.cells_data_size_t.end()) ? size_it->second : 0;
            const NumType* src = value.first ? *(value.first) : nullptr;

            (this->cells_data)[key] = 
            std::make_pair
            (
                clone_buffer(src, n_bytes),
                value.second
            );
            (this->cells_data_size_t)[key] = n_bytes;
        }

        // nodes data
        for (const auto& item: other.nodes_data)
        {
            const auto& key = item.first;
            const auto& value = item.second;

            const auto size_it = other.nodes_data_size_t.find(key);
            size_t n_bytes = (size_it != other.nodes_data_size_t.end()) ? size_it->second : 0;
            const NumType* src = value.first ? *(value.first) : nullptr;

            (this->nodes_data)[key] = 
            std::make_pair
            (
                clone_buffer(src, n_bytes),
                value.second
            );
            (this->nodes_data_size_t)[key] = n_bytes;
        }

        // The host staging buffer is not copied; it is created lazily by PullCPUArray.
        SubBuffer      = nullptr;
        SubBuffer_size = 0;
    }

    // Destructor: releases the thickness array, every cells/nodes data array
    // and the host staging buffer.
    template <typename NumType, MemType memtype>
    Mesh<NumType, memtype>::~Mesh()
    {   
        // A default-constructed mesh has no thickness array to release.
        if (cells_thickness)
        {
            memproc::dealloc<memtype>((void*&)(*cells_thickness.get()), cells_thickness_size_t);
        }

        // Iterate by reference: copying an entry would copy its shared_ptr.
        for (auto& item: this->cells_data)
        {
            auto& buffer = item.second.first;
            if (buffer)
            {
                memproc::dealloc<memtype>((void*&)(*buffer.get()), cells_data_size_t[item.first]);
            }
        }

        // Node arrays are sized by nodes_data_size_t, not cells_data_size_t.
        for (auto& item: this->nodes_data)
        {
            auto& buffer = item.second.first;
            if (buffer)
            {
                memproc::dealloc<memtype>((void*&)(*buffer.get()), nodes_data_size_t[item.first]);
            }
        }

        // delete[] on a null pointer is a no-op, so no size check is required.
        delete[] SubBuffer;
        SubBuffer = nullptr;
        SubBuffer_size = 0;
    }

    // Returns the number of cells, derived from the byte size of the thickness array.
    template <typename NumType, MemType memtype>
    int Mesh<NumType, memtype>::GetCellsNum() const
    {
        const auto n_cells = cells_thickness_size_t / sizeof(NumType);
        return static_cast<int>(n_cells);
    }

    // Returns the number of nodes, which is the number of cells plus one.
    template <typename NumType, MemType memtype>
    int Mesh<NumType, memtype>::GetNodesNum() const
    {
        const int n_cells = static_cast<int>(cells_thickness_size_t / sizeof(NumType));
        return n_cells + 1;
    }

    // Registers a new single-value variable initialised to 0 and returns a handle to it.
    //
    // varname - name of the variable; THERMO_ERR is raised if it is already registered
    // visible - visibility flag stored with the variable (SaveTXT writes only visible ones)
    template <typename NumType, MemType memtype>
    std::shared_ptr<NumType> Mesh<NumType, memtype>::CreateSingleData(const std::string& varname, bool visible)
    {
        const bool already_exists = (single_data.count(varname) != 0);
        if (already_exists)
        {
            THERMO_ERR("Variable \'" + varname + "\' already exists, could not create single variable!");
        }

        auto& slot = single_data[varname];
        slot = {std::make_shared<NumType>(NumType(0)), visible};
        return slot.first;
    }

    // Registers a new per-cell array with as many bytes as the thickness array
    // and returns a handle to it.
    //
    // varname - name of the variable; THERMO_ERR is raised if it is already registered
    // visible - visibility flag stored with the variable (SaveTXT writes only visible ones)
    //
    // NOTE(review): the initial contents depend on memproc::realloc, which is not
    // visible here - confirm that it zero-fills before relying on zeros.
    template <typename NumType, MemType memtype>
    std::shared_ptr<NumType*> Mesh<NumType, memtype>::CreateCellsData(const std::string& varname, bool visible)
    {
        if (cells_data.count(varname) != 0)
        {
            THERMO_ERR("Variable \'" + varname + "\' already exists, could not create cell variable!");
        }

        // Start from a null pointer so the allocator never receives an indeterminate value.
        NumType* new_vec = nullptr;
        cells_data_size_t[varname] = 0;
        memproc::realloc<memtype>((void *&)(new_vec), cells_data_size_t[varname], cells_thickness_size_t);

        cells_data[varname] = {std::make_shared<NumType*>(new_vec), visible};
        return cells_data[varname].first;
    }

    // Registers a new per-node array holding GetNodesNum() elements
    // and returns a handle to it.
    //
    // varname - name of the variable; THERMO_ERR is raised if it is already registered
    // visible - visibility flag stored with the variable (SaveTXT writes only visible ones)
    //
    // NOTE(review): the initial contents depend on memproc::realloc, which is not
    // visible here - confirm that it zero-fills before relying on zeros.
    template <typename NumType, MemType memtype>
    std::shared_ptr<NumType*> Mesh<NumType, memtype>::CreateNodesData(const std::string& varname, bool visible)
    {
        if (nodes_data.count(varname) != 0)
        {
            THERMO_ERR("Variable \'" + varname + "\' already exists, could not create node variable!");
        }

        // Start from a null pointer so the allocator never receives an indeterminate value.
        NumType* new_vec = nullptr;
        nodes_data_size_t[varname] = 0;
        memproc::realloc<memtype>((void *&)(new_vec), nodes_data_size_t[varname], GetNodesNum() * sizeof(NumType));
        nodes_data[varname] = {std::make_shared<NumType*>(new_vec), visible};
        return nodes_data[varname].first;
    }

    // Removes the single-value variable `varname`; does nothing if it is not registered.
    template <typename NumType, MemType memtype>
    void Mesh<NumType, memtype>::DeleteSingleData(const std::string& varname)
    {
        auto entry = single_data.find(varname);
        if (entry != single_data.end())
        {
            // Drop the mesh's reference to the value, then remove the record itself.
            entry->second.first.reset();
            single_data.erase(entry);
        }
    }

    // Releases the per-cell array `varname` and removes it from the mesh;
    // does nothing if it is not registered.
    template <typename NumType, MemType memtype>
    void Mesh<NumType, memtype>::DeleteCellsData(const std::string& varname)
    {
        auto entry = cells_data.find(varname);
        if (entry == cells_data.end())
        {
            return;
        }

        // Free the underlying array (skipped when the handle itself is empty).
        if (entry->second.first)
        {
            memproc::dealloc<memtype>((void*&)(*(entry->second.first).get()), cells_data_size_t[varname]);
        }
        cells_data.erase(entry);

        // Also drop the size record so no stale bookkeeping entry is left behind.
        cells_data_size_t.erase(varname);
    }

    // Releases the per-node array `varname` and removes it from the mesh;
    // does nothing if it is not registered.
    template <typename NumType, MemType memtype>
    void Mesh<NumType, memtype>::DeleteNodesData(const std::string& varname)
    {
        auto entry = nodes_data.find(varname);
        if (entry == nodes_data.end())
        {
            return;
        }

        // Free the underlying array (skipped when the handle itself is empty).
        if (entry->second.first)
        {
            memproc::dealloc<memtype>((void*&)(*(entry->second.first).get()), nodes_data_size_t[varname]);
        }
        nodes_data.erase(entry);

        // Also drop the size record so no stale bookkeeping entry is left behind.
        nodes_data_size_t.erase(varname);
    }

    // Returns the handle of the single-value variable `varname`;
    // THERMO_ERR is raised when no such variable is registered.
    template <typename NumType, MemType memtype>
    std::shared_ptr<NumType> Mesh<NumType, memtype>::GetSingleData(const std::string& varname)
    {
        const bool is_missing = (single_data.count(varname) == 0);
        if (is_missing)
        {
            THERMO_ERR("There is no single variable: \'" + varname + "\' - can't get!");
        }

        auto& record = single_data[varname];
        return record.first;
    }

    // Returns the handle of the per-cell array `varname`;
    // THERMO_ERR is raised when no such variable is registered.
    template <typename NumType, MemType memtype>
    std::shared_ptr<NumType*> Mesh<NumType, memtype>::GetCellsData(const std::string& varname)
    {
        const bool is_missing = (cells_data.count(varname) == 0);
        if (is_missing)
        {
            THERMO_ERR("There is no cell variable: \'" + varname + "\' - can't get!");
        }

        auto& record = cells_data[varname];
        return record.first;
    }

    // Returns the handle of the per-node array `varname`;
    // THERMO_ERR is raised when no such variable is registered.
    template <typename NumType, MemType memtype>
    std::shared_ptr<NumType*> Mesh<NumType, memtype>::GetNodesData(const std::string& varname)
    {
        const bool is_missing = (nodes_data.count(varname) == 0);
        if (is_missing)
        {
            THERMO_ERR("There is no node variable: \'" + varname + "\' - can't get!");
        }

        auto& record = nodes_data[varname];
        return record.first;
    }

    // Returns the shared handle to the array of cell thicknesses.
    template <typename NumType, MemType memtype>
    std::shared_ptr<NumType*> Mesh<NumType, memtype>::GetCellsThickness()
    {
        return this->cells_thickness;
    }


    // Returns the sum of all cell thicknesses.
    template <typename NumType, MemType memtype>
    NumType Mesh<NumType, memtype>::GetTotalThickness() const
    {
        NumType* const thicknesses = *cells_thickness.get();
        const auto n_cells = cells_thickness_size_t / sizeof(NumType);
        return sum_vec<NumType, memtype>(thicknesses, n_cells);
    }

    // Clears the visibility flag of the single-value variable `varname`;
    // unknown names are ignored.
    template <typename NumType, MemType memtype>
    void Mesh<NumType, memtype>::MuteSingleData(const std::string& varname)
    {
        auto entry = single_data.find(varname);
        if (entry != single_data.end())
        {
            entry->second.second = false;
        }
    }

    // Clears the visibility flag of the per-cell variable `varname`;
    // unknown names are ignored.
    template <typename NumType, MemType memtype>
    void Mesh<NumType, memtype>::MuteCellData(const std::string& varname)
    {
        auto entry = cells_data.find(varname);
        if (entry != cells_data.end())
        {
            entry->second.second = false;
        }
    }

    // Clears the visibility flag of the per-node variable `varname`;
    // unknown names are ignored.
    template <typename NumType, MemType memtype>
    void Mesh<NumType, memtype>::MuteNodeData(const std::string& varname)
    {
        auto entry = nodes_data.find(varname);
        if (entry != nodes_data.end())
        {
            entry->second.second = false;
        }
    }

    // Sets the visibility flag of the single-value variable `varname`;
    // unknown names are ignored.
    template <typename NumType, MemType memtype>
    void Mesh<NumType, memtype>::UnmuteSingleData(const std::string& varname)
    {
        auto entry = single_data.find(varname);
        if (entry != single_data.end())
        {
            entry->second.second = true;
        }
    }

    // Sets the visibility flag of the per-cell variable `varname`;
    // unknown names are ignored.
    template <typename NumType, MemType memtype>
    void Mesh<NumType, memtype>::UnmuteCellData(const std::string& varname)
    {
        auto entry = cells_data.find(varname);
        if (entry != cells_data.end())
        {
            entry->second.second = true;
        }
    }

    // Sets the visibility flag of the per-node variable `varname`;
    // unknown names are ignored.
    template <typename NumType, MemType memtype>
    void Mesh<NumType, memtype>::UnmuteNodeData(const std::string& varname)
    {
        auto entry = nodes_data.find(varname);
        if (entry != nodes_data.end())
        {
            entry->second.second = true;
        }
    }

    // Writes the mesh to the text file `<filename>.txt`.
    //
    // Layout: the cell thicknesses, then every visible single value, every visible
    // per-cell array and every visible per-node array. Each array occupies one line
    // of space-separated values. Arrays are first copied to the host staging buffer
    // via PullCPUArray. THERMO_ERR is raised when the file cannot be opened.
    template <typename NumType, MemType memtype>
    void Mesh<NumType, memtype>::SaveTXT(const std::string& filename)
    {
        const std::string filename_txt = filename + ".txt";

        // Stack-allocated stream: it is closed and released even if an error path throws.
        std::fstream ofs;
        ofs.open(filename_txt, std::ios::out);

        if (!ofs.is_open())
        {
            THERMO_ERR("can't open file "+ filename_txt + " for logging!");
        }

        // Copies `count` values of `array` to the host buffer and writes them on one line.
        auto write_array = [this, &ofs](const NumType* array, const int count)
        {
            this->PullCPUArray(array, count);
            for (int i = 0; i < count; i++)
            {
                if (i != 0)
                {
                    ofs << " ";
                }
                ofs << this->SubBuffer[i];
            }
            ofs << "\n";
        };

        // cell thickness info
        ofs << "### cells_thickness_array ###\n";
        if (cells_thickness)
        {
            write_array(*cells_thickness.get(), static_cast<int>(cells_thickness_size_t / sizeof(NumType)));
        }
        else
        {
            // mesh without layers: keep the (empty) data line so the layout stays the same
            ofs << "\n";
        }

        // single data
        ofs << "#### Single data ###\n";
        for (const auto& item: single_data)
        {
            const auto& val = item.second;

            if (val.second)
            {
                ofs << item.first << "\n";
                ofs << *(val.first);
                ofs << "\n";
            }
        }

        // cells data
        ofs << "#### Cells data ###\n";
        for (const auto& item: cells_data)
        {
            const auto& key = item.first;
            const auto& val = item.second;

            if (val.second)
            {
                ofs << key << "\n";
                write_array(*val.first.get(), static_cast<int>(cells_data_size_t[key] / sizeof(NumType)));
            }
        }

        // nodes data
        ofs << "#### Nodes data ###\n";
        for (const auto& item: nodes_data)
        {
            const auto& key = item.first;
            const auto& val = item.second;

            if (val.second)
            {
                ofs << key << "\n";
                write_array(*val.first.get(), static_cast<int>(nodes_data_size_t[key] / sizeof(NumType)));
            }
        }

        ofs.close();

        std::cout << "Mesh saved to \'" + filename_txt + "\'\n";
    }

    // Writes the mesh to `<filename><postscript>.txt`, where `postscript`
    // is formatted with a field width of 5 and '0' as the fill character.
    template <typename NumType, MemType memtype>
    void Mesh<NumType, memtype>::SaveTXT(const std::string& filename, int postscript)
    {
        std::stringstream suffix;
        suffix << std::setfill('0') << std::setw(5) << postscript;

        const std::string numbered_name = filename + suffix.str();
        SaveTXT(numbered_name);
    }

// #ifdef USE_JSON_OUTPUT

//     template <typename NumType, MemType memtype>
//     void Mesh<NumType, memtype>::SaveJSON(const std::string& filename) const
//     {
//         // create empty json object
//         json j;

//         // cells thickness
//         j["cells_thickness_array"] = *cells_thickness;

//         // single data
//         for (auto item: single_data)
//         {
//             auto key = item.first;
//             auto val = item.second;
//             if (val.second)
//             {
//                 j["single data"][key] = *(val.first);
//             }
//         }

//         // cells data
//         for (auto item: cells_data)
//         {
//             auto key = item.first;
//             auto val = item.second;
//             if (val.second)
//             {
//                 j["cells data"][key] = *(val.first);
//             }
//         }

//         // nodes data
//         for (auto item: nodes_data)
//         {
//             auto key = item.first;
//             auto val = item.second;

//             if (val.second)
//             {
//                 j["nodes data"][key] = *(val.first);
//             }
//         }

//         // write json object to file
//         std::string filename_json = filename + ".json";
//         std::fstream* ofs = new std::fstream;
//         ofs->open(filename_json, std::ios::out);

//         if (!ofs->is_open())
//         {
//             THERMO_ERR("can't open file "+ filename_json + " for logging!");
//         }

//         *ofs << std::setw(4) << j;

//         ofs->close(); 
//         delete ofs;

//         std::cout << "Mesh saved to \'" + filename_json + "\'\n";
//     }

// #endif

// #ifdef USE_JSON_OUTPUT

//     template <typename NumType, MemType memtype>
//     void Mesh<NumType, memtype>::SaveJSON(const std::string& filename, int postscript) const
//     {
//         std::string file = filename;
//         std::stringstream ss;
//         ss << std::setfill('0') << std::setw(5) << postscript;
//         file += ss.str();
//         SaveJSON(file);
//     }
// #endif

    // Returns true when a per-cell variable named `varname` is registered.
    template <typename NumType, MemType memtype>
    bool Mesh<NumType, memtype>::CheckCellsDataExistency(const std::string& varname) const
    {
        return cells_data.count(varname) != 0;
    }

    // Returns true when a per-node variable named `varname` is registered.
    template <typename NumType, MemType memtype>
    bool Mesh<NumType, memtype>::CheckNodesDataExistency(const std::string& varname) const
    {
        return nodes_data.count(varname) != 0;
    }

    // Returns true when a single-value variable named `varname` is registered.
    template <typename NumType, MemType memtype>
    bool Mesh<NumType, memtype>::CheckSingleDataExistency(const std::string& varname) const
    {
        return single_data.count(varname) != 0;
    }

    // Copies `n` elements of `array` (stored in `memtype` memory) into the host
    // staging buffer `SubBuffer`, growing the buffer when it is too small.
    //
    // array - source array with at least `n` elements
    // n     - number of elements to copy; nothing is done when n <= 0
    template <typename NumType, MemType memtype>
    void Mesh<NumType, memtype>::PullCPUArray(const NumType* array, const int n)
    {
        // A non-positive count would wrap to a huge unsigned size below.
        if (n <= 0)
        {
            return;
        }

        const size_t n_bytes = static_cast<size_t>(n) * sizeof(NumType);
        if(n_bytes > SubBuffer_size)
        {
            // Allocate first: if `new` throws, SubBuffer and SubBuffer_size stay
            // consistent and the destructor does not delete a dangling pointer.
            NumType* grown = new NumType[n];
            delete[] SubBuffer;
            SubBuffer = grown;
            SubBuffer_size = n_bytes;
        }
        memproc::memcopy<MemType::CPU, memtype>(SubBuffer, array, n_bytes);
    }

    // Explicit instantiations of the Mesh class template. The member functions
    // are defined in this source file, so each supported <NumType, MemType>
    // combination is instantiated here.
    template class Mesh<float, MemType::CPU>;
    template class Mesh<double, MemType::CPU>;

#ifdef INCLUDE_CUDA
    // GPU variants are instantiated only when INCLUDE_CUDA is defined.
    template class Mesh<float, MemType::GPU>;
    template class Mesh<double, MemType::GPU>;
#endif
}