import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import shutil
import natsort
import imageio
from .ProcData import ProcData
import copy
from matplotlib.cm import ScalarMappable

class Plotter:
    """Load data files through ``ProcData`` and render them with matplotlib.

    Typical use: construct, call :meth:`set` with parsed command-line
    arguments, then invoke the bound method stored in ``args.func``
    (``plot``, ``ani_plot``, ``avg_plot``, ``plot_contour``,
    ``ani_plot_contour``, ``plot_diff`` or ``dump``).
    """

    # Name prefix for the private temporary directory that holds animation
    # frames while a GIF is assembled.
    ani_pngs_dir = "plotter_lib_pngs/"

    def __init__(self):
        # Seconds per animation frame.  NOTE(review): currently unused -- the
        # GIFs are written with a fixed fps=5.
        self.ani_period = float(1./20.)
        self.filename = None           # list of input file paths
        self.ndim = None               # number of leading coordinate columns
        self.out = None                # output directory; concatenated with names, so keep the trailing '/'
        self.oname = None              # list of output file names
        self.var = None                # names of the variables (columns) to plot
        self.mval = None               # forwarded to ProcData.process_file
        self.function = None           # bound method selected by the caller (args.func)
        self.file_column_names = None  # column names shared by all input files
        self.title = None              # optional title overriding the variable name
        self.file_data = None          # one ProcData per input file, file_data[k] <-> filename[k]
        self.plt_max_val = None
        self.plt_min_val = None
        self.min_y = None              # optional per-variable lower plot bounds
        self.max_y = None              # optional per-variable upper plot bounds
        self.min_max_var_vals = None   # data bounds computed by the __get_min_max_* helpers
        self.transpose = None          # per-variable flags that swap the plot axes
        self.fig_count = None
        self.if_manual_plot = False    # show figures interactively instead of closing them
        self.if_save_result = True     # write figures/animations to disk

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _fail(message):
        """Report a fatal configuration/data error on stderr and exit with status -1."""
        print(message, file=sys.stderr)
        sys.exit(-1)

    @staticmethod
    def _ensure_dir(path):
        """Create directory *path* (with parents) if missing; no-op for an empty path."""
        if path:
            os.makedirs(path, exist_ok=True)

    def _show_or_close(self, fig):
        """Show *fig* interactively when ``if_manual_plot`` is set, otherwise close it."""
        if self.if_manual_plot:
            plt.show()
        else:
            plt.close(fig)

    @staticmethod
    def _plot_line(data, x_name, y_name, transposed):
        """Draw ``data[y_name]`` against ``data[x_name]``; swap the axes if *transposed*."""
        x, y = data[x_name], data[y_name]
        if transposed:
            x, y = y, x
        plt.plot(x, y, linewidth=4)

    @staticmethod
    def _contour_levels(vmin, vmax, count=25):
        """Return *count* evenly spaced contour levels spanning [vmin, vmax].

        Returns None (letting matplotlib choose) when the range is empty,
        inverted or NaN, because contourf rejects non-increasing levels.
        """
        if not vmax > vmin:
            return None
        return np.linspace(vmin, vmax, count)

    def _contour_output_names(self, default_ext):
        """Return ``(names, extension)`` for contour outputs.

        Uses ``oname`` verbatim when given, otherwise the variable names plus
        *default_ext*.
        """
        if self.oname is None:
            return self.var, default_ext
        return self.oname, ""

    def _contour_ranges(self):
        """Return per-variable ``(min_dict, max_dict)`` colour ranges.

        User-supplied ``min_y``/``max_y`` (one entry per variable) take
        precedence over the data bounds stored by ``__get_min_max_bar``.
        """
        if self.max_y is None:
            max_val = {var: self.min_max_var_vals[var][1] for var in self.var}
        else:
            max_val = dict(zip(self.var, self.max_y))

        if self.min_y is None:
            min_val = {var: self.min_max_var_vals[var][0] for var in self.var}
        else:
            min_val = dict(zip(self.var, self.min_y))
        return min_val, max_val

    def _natural_order(self):
        """Return ``file_data`` reordered by the natural sort order of ``filename``.

        ``file_data[k]`` is loaded from ``filename[k]`` (see ``set``), so the
        two are sorted together and every frame stays paired with its file.
        """
        return [self.file_data[k] for k in natsort.index_natsorted(self.filename)]

    def _temp_frame_dir(self):
        """Return a fresh ``tempfile.TemporaryDirectory`` for animation frames.

        A private directory means a pre-existing directory named like
        ``ani_pngs_dir`` is never deleted and concurrent runs cannot overwrite
        each other's frames.  The directory is removed when the returned
        context manager exits, even if plotting raises.
        """
        import tempfile

        prefix = os.path.basename(os.path.normpath(self.ani_pngs_dir)) + "_"
        return tempfile.TemporaryDirectory(prefix=prefix)

    def _var_min_max(self):
        """Return ``{var: np.array([min, max])}`` over all loaded files, ignoring NaNs."""
        bounds = {}
        for var in self.var:
            columns = [proc.data[var] for proc in self.file_data]
            bounds[var] = np.array([min(np.nanmin(col) for col in columns),
                                    max(np.nanmax(col) for col in columns)])
        return bounds

    # ----------------------------------------------------------- configuration

    def __check_arg_dim_equiv(self, args):
        """Exit if both ``args.var`` and ``args.mval`` are given with different lengths."""
        if (args.var is not None and args.mval is not None
                and len(args.var) != len(args.mval)):
            self._fail("The count of var assumed to be equal to the count of mval")

    def __get_variable_names(self):
        """Store the column names shared by every loaded file in ``file_column_names``.

        Reads ``variable_names`` from each ``ProcData`` (presumably populated
        by ``ProcData.get_variable_names``).  Exits if no files were loaded or
        if any file's names differ from the first file's.
        """
        var_names = [np.array(proc.variable_names, dtype=object) for proc in self.file_data]
        if not var_names:
            self._fail("Undefined variable names")

        # Exact, order-sensitive comparison; also safe for differing lengths.
        reference = list(var_names[0])
        if any(list(names) != reference for names in var_names[1:]):
            self._fail("All files must have the same variable names")

        self.file_column_names = var_names[0]

    def set(self, args, **kwargs):
        """Configure the plotter from parsed arguments and load the input files.

        Parameters
        ----------
        args : namespace
            Reads ``filename``, ``ndim``, ``out``, ``oname``, ``var``,
            ``mval``, ``func``, ``title``, ``min_y``, ``max_y`` and
            ``transpose``.  ``func`` is compared against this instance's bound
            methods, so it should be one of them.
        **kwargs
            ``if_manual_plot`` (default False) shows figures interactively;
            ``if_save_result`` (default True) writes results to disk.
        """
        self.__check_arg_dim_equiv(args)
        self.filename = args.filename
        self.ndim = args.ndim
        self.out = args.out
        self.oname = args.oname
        self.var = args.var
        self.mval = args.mval
        self.function = args.func
        self.title = args.title
        self.min_y = args.min_y
        self.max_y = args.max_y
        self.transpose = args.transpose

        self.if_manual_plot = kwargs.get('if_manual_plot', False)
        self.if_save_result = kwargs.get('if_save_result', True)

        # file_data[k] corresponds to filename[k]; _natural_order relies on it.
        self.file_data = []
        for fname in self.filename:
            proc = ProcData(fname)
            proc.get_variable_names()
            self.file_data.append(proc)
        self.__get_variable_names()

        if args.func == self.dump:
            return  # dump only needs the column names

        # The plot type fixes how many leading columns are coordinates.
        if args.func == self.plot or args.func == self.ani_plot:
            self.ndim = 1
        elif args.func == self.plot_contour or args.func == self.ani_plot_contour:
            self.ndim = 2
        elif args.func == self.avg_plot:
            self.ndim = 3

        if self.var is None and len(self.file_column_names) != 0:
            # Default: every column after the ndim coordinate columns.
            self.var = list(self.file_column_names[self.ndim:])

        for proc in self.file_data:
            proc.process_file(self.ndim, self.var, self.mval)
        self.fig_count = len(self.var)

        if self.transpose is None:
            self.transpose = [False] * len(self.var)

    # ------------------------------------------------------------ data bounds

    def __get_min_max_ax(self):
        """Store the global ``[min, max]`` over all selected variables and files."""
        per_var = self._var_min_max()
        lows = [bounds[0] for bounds in per_var.values()]
        highs = [bounds[1] for bounds in per_var.values()]
        self.min_max_var_vals = np.array([np.min(lows), np.max(highs)])

    def __get_min_max_bar(self):
        """Store per-variable ``[min, max]`` arrays (over all files) keyed by variable name.

        ``filename`` is left untouched so it stays aligned with ``file_data``.
        """
        self.min_max_var_vals = self._var_min_max()

    # -------------------------------------------------------------- renderers

    def __plot(self):
        """Plot each selected variable of the first file as a line against column 0.

        Saved to ``out + oname[0]`` when ``if_save_result`` is set.
        """
        self._ensure_dir(self.out)
        x_name = self.file_column_names[0]
        data = self.file_data[0].data

        fig = plt.figure()
        for i, y_name in enumerate(self.var):
            self._plot_line(data, x_name, y_name, self.transpose[i])

        plt.legend(self.var)
        plt.xlabel(x_name, fontsize=10, fontweight='bold')

        # Save before showing: after an interactive window is closed the
        # figure may be torn down and savefig can produce an empty image.
        if self.if_save_result:
            fig.savefig(self.out + self.oname[0])
        self._show_or_close(fig)

    def __ani_plot(self):
        """Write an animated GIF with one line-plot frame per input file.

        Frames follow the natural sort order of ``filename``.  The y-range is
        fixed across frames (``min_y[0]``/``max_y[0]`` if given, otherwise the
        bounds from ``__get_min_max_ax``).  The GIF is written to
        ``out + oname[0]``, like every other output.  Does nothing unless
        ``if_save_result`` is set.
        """
        if not self.if_save_result:
            return
        self._ensure_dir(self.out)

        x_name = self.file_column_names[0]
        max_val = self.min_max_var_vals[1] if self.max_y is None else self.max_y[0]
        min_val = self.min_max_var_vals[0] if self.min_y is None else self.min_y[0]

        images = []
        with self._temp_frame_dir() as frame_dir:
            for counter, proc in enumerate(self._natural_order()):
                fig = plt.figure()
                plt.ylim([min_val, max_val])
                for i, y_name in enumerate(self.var):
                    self._plot_line(proc.data, x_name, y_name, self.transpose[i])

                plt.legend(self.var)
                plt.xlabel(x_name, fontsize=10, fontweight='bold')

                # Name frames by index: input basenames can collide
                # (e.g. "out.0001.dat" and "out.0002.dat").
                png = os.path.join(frame_dir, "frame" + str(counter) + ".png")
                fig.savefig(png)
                plt.close(fig)
                images.append(imageio.v2.imread(png))

        imageio.mimsave(self.out + self.oname[0], images, fps=5, loop=0)

    def __plot_contour(self):
        """Draw one filled-contour figure per selected variable from the first file.

        Column 0 is the x axis and column 1 the y axis.  Colour ranges come
        from ``min_y``/``max_y`` or the data bounds.  For ``plot_diff``,
        matplotlib's automatic scaling is used instead.  Outputs are
        ``out + name + ext`` (see ``_contour_output_names``, default ``.png``).
        """
        self._ensure_dir(self.out)
        x_name = self.file_column_names[0]
        y_name = self.file_column_names[1]
        fig_names, fig_end = self._contour_output_names(".png")
        min_val, max_val = self._contour_ranges()
        is_diff = self.function == self.plot_diff

        data = self.file_data[0].data
        X = data[x_name]
        Y = data[y_name]

        for i in range(self.fig_count):
            var = self.var[i]
            title = var if self.title is None else self.title
            Z = data[var]

            fig, ax = plt.subplots(1, 1)
            if is_diff:
                contour_kwargs = {}
            else:
                vmin, vmax = min_val[var], max_val[var]
                contour_kwargs = {"vmin": vmin, "vmax": vmax,
                                  "levels": self._contour_levels(vmin, vmax)}

            if self.transpose[i]:
                cp = ax.contourf(Y, X, Z.T, **contour_kwargs)
            else:
                cp = ax.contourf(X, Y, Z, **contour_kwargs)
            fig.colorbar(cp)
            ax.set_title(title)
            ax.set_xlabel(x_name)
            ax.set_ylabel(y_name)

            if self.if_save_result:
                fig.savefig(self.out + fig_names[i] + fig_end)
            self._show_or_close(fig)

    def __ani_plot_contour(self):
        """Write one animated contour GIF per variable, one frame per input file.

        Frames follow the natural sort order of ``filename``.  All frames of a
        variable share its colour range and contour levels, so the colourbar
        stays fixed.  Outputs are ``out + name + ext`` (see
        ``_contour_output_names``, default ``.gif``).  Does nothing unless
        ``if_save_result`` is set.
        """
        if not self.if_save_result:
            return
        self._ensure_dir(self.out)

        x_name = self.file_column_names[0]
        y_name = self.file_column_names[1]
        fig_names, fig_end = self._contour_output_names(".gif")
        min_val, max_val = self._contour_ranges()

        # The grid is taken from the first file, as before.
        X = self.file_data[0].data[x_name]
        Y = self.file_data[0].data[y_name]
        frames = self._natural_order()

        with self._temp_frame_dir() as frame_dir:
            for i, var in enumerate(self.var):
                title = var if self.title is None else self.title
                vmin, vmax = min_val[var], max_val[var]
                levels = self._contour_levels(vmin, vmax)

                images = []
                for counter, proc in enumerate(frames):
                    fig, ax = plt.subplots(1, 1)
                    ax.set_title(title)
                    ax.set_xlabel(x_name)
                    ax.set_ylabel(y_name)

                    Z = proc.data[var]
                    if self.transpose[i]:
                        cp = ax.contourf(Y, X, Z.T, vmin=vmin, vmax=vmax, levels=levels)
                    else:
                        cp = ax.contourf(X, Y, Z, vmin=vmin, vmax=vmax, levels=levels)
                    fig.colorbar(cp)

                    # Each frame is read back immediately, so an index-based
                    # name is enough and is independent of the variable name.
                    png = os.path.join(frame_dir, "frame" + str(counter) + ".png")
                    fig.savefig(png)
                    plt.close(fig)
                    images.append(imageio.v2.imread(png))

                imageio.mimsave(self.out + fig_names[i] + fig_end, images, fps=5, loop=0)

    def __avg(self, data, var_name):
        """Average ``data[var_name]`` over the first two grid axes.

        ``data['cx']``, ``data['cy']`` and ``data['cz']`` give the grid size.
        The first ``cx*cy*cz`` values of the flattened field are interpreted
        with x varying fastest (flat index ``i + cx*j + cx*cy*k``).

        Returns
        -------
        numpy.ndarray
            Length-``cz`` array of float averages.
        """
        cx = data['cx']
        cy = data['cy']
        cz = data['cz']

        flat = np.asarray(data[var_name], dtype=float).ravel()[:cx * cy * cz]
        # Fortran order maps flat index i + cx*j + cx*cy*k to [i, j, k].
        cube = np.reshape(flat, (cx, cy, cz), order='F')
        return np.average(cube, axis=(1, 0))

    def __avg_plot(self):
        """Plot each variable's x/y-plane average (see ``__avg``) against column 2.

        Uses the first file only; saved to ``out + oname[0]`` when
        ``if_save_result`` is set.
        """
        self._ensure_dir(self.out)
        fig = plt.figure()
        x_name = self.file_column_names[2]
        data = self.file_data[0].data
        x = data[x_name]

        for i, var in enumerate(self.var):
            avg_data = self.__avg(data, var)
            if self.transpose[i]:
                plt.plot(avg_data, x, linewidth=4)
            else:
                plt.plot(x, avg_data, linewidth=4)

        plt.legend(self.var)
        plt.xlabel(x_name, fontsize=10, fontweight='bold')

        if self.if_save_result:
            fig.savefig(self.out + self.oname[0])
        self._show_or_close(fig)

    def __dump(self):
        """Print the column names shared by all input files on one line."""
        for variable_name in self.file_column_names:
            print(variable_name, end=' ')
        print('\n')

    def __plot_diff(self):
        """Plot ``file_data[0] - file_data[1]`` for the selected variables.

        Non-selected columns (coordinates) are copied from the first file.
        ``file_data`` is replaced by a single ``ProcData`` holding the
        difference, and the title becomes ``"<file0> - <file1>"``.  Plots as
        lines for ``ndim == 1`` and as contours for ``ndim == 2``.  Exits if
        fewer than two files are loaded or any column's shape differs.
        """
        if len(self.file_data) < 2:
            self._fail("plot_diff requires two input files")

        first = self.file_data[0].data
        second = self.file_data[1].data
        if any(first[name].shape != second[name].shape for name in self.file_column_names):
            self._fail("Data dimensions do not match")

        diff = {name: first[name] - second[name] for name in self.var}
        for name in set(self.file_column_names) - set(self.var):
            diff[name] = first[name]

        diff_proc = ProcData()
        diff_proc.data = diff
        self.file_data = [diff_proc]

        basename0 = os.path.basename(self.filename[0])
        basename1 = os.path.basename(self.filename[1])
        self.title = str(basename0) + ' - ' + str(basename1)

        if self.ndim == 1:
            self.__get_min_max_ax()
            self.__plot()
        elif self.ndim == 2:
            self.__get_min_max_bar()
            self.__plot_contour()

    # ------------------------------------------------------------- public API

    def dump(self):
        """Print the shared column names."""
        self.__dump()

    def plot(self):
        """Line plot of the selected variables from the first file."""
        self.__get_min_max_ax()
        self.__plot()

    def ani_plot(self):
        """Animated line plot (GIF) over all input files."""
        self.__get_min_max_ax()
        self.__ani_plot()

    def avg_plot(self):
        """Line plot of x/y-plane averages of the selected variables."""
        self.__get_min_max_ax()
        self.__avg_plot()

    def plot_contour(self):
        """Filled-contour plot of each selected variable from the first file."""
        self.__get_min_max_bar()
        self.__plot_contour()

    def ani_plot_contour(self):
        """Animated filled-contour plots (one GIF per variable) over all input files."""
        self.__get_min_max_bar()
        self.__ani_plot_contour()

    def plot_diff(self):
        """Plot the difference between the first two input files."""
        self.__plot_diff()

    def get_data(self):
        """Return deep copies of each loaded file's ``data``, in ``file_data`` order."""
        return_data = [copy.deepcopy(data.data) for data in self.file_data]
        return return_data