Skip to content
Snippets Groups Projects
Plotter.py 7.19 KiB
Newer Older
数学の武士's avatar
数学の武士 committed
import copy
import os
import shutil
import sys
import tempfile

import imageio
import matplotlib.pyplot as plt
import natsort
import numpy as np

from .ProcData import ProcData

class Plotter:
    """Load data files through ProcData and render them with matplotlib.

    ``set(args)`` loads every file in ``args.filename`` and prepares the
    instance for the method passed as ``args.func`` (``plot``, ``ani_plot``,
    ``plot_contour``, ``avg_plot`` or ``dump``); that method is presumably
    invoked by the caller afterwards.
    """

    # Former fixed location of the temporary PNG frames used by ani_plot.
    # Frames now go to a private temporary directory; the attribute is kept
    # only so external code that reads it keeps working.
    ani_pngs_dir = "plotter_lib_pngs/"

    def __init__(self):
        self.filename = ""
        self.out = ""
        self.oname = ""
        self.var = ""
        self.mval = ""
        self.function = ""
        self.variable_names = ""
        # Normally populated by set(); initialised here so the attributes the
        # plotting methods read always exist.
        self.data = []
        self.fig_count = 0
        self.if_manual_plot = False
        self.if_save_result = True

    @staticmethod
    def _fail(message):
        """Print *message* to stderr and terminate with exit status -1."""
        print(message, file=sys.stderr)
        sys.exit(-1)

    @staticmethod
    def _ensure_dir(path):
        """Create directory *path* (with parents) unless it is empty or None."""
        if path:
            os.makedirs(path, exist_ok=True)

    def _out_path(self, name):
        """Return *name* placed inside the output directory ``self.out``.

        os.path.join inserts the separator that plain ``out + name`` dropped
        when ``out`` had no trailing slash.
        """
        return os.path.join(self.out or "", name)

    def _show_or_close(self, fig):
        """Show figures interactively when requested, otherwise close *fig*."""
        if self.if_manual_plot:
            plt.show()
        else:
            plt.close(fig)

    def __check_arg_dim_equiv(self, args):
        """Exit if ``args.var`` and ``args.mval`` are both given with different lengths."""
        if args.var is not None and args.mval is not None:
            if len(args.var) != len(args.mval):
                self._fail("The count of var assumed to be equal to the count of mval")

    def __get_variable_names(self):
        """Set ``self.variable_names`` from the loaded files, requiring all to agree.

        Exits the process when no files were loaded, or when any file's
        variable names differ from the first file's in content, order or count.
        """
        var_names = [np.array(d.variable_names, dtype=object) for d in self.data]

        if not var_names:
            self._fail("Undefined variable names")

        reference = list(var_names[0])
        for names in var_names[1:]:
            # Compare whole sequences: the former element-wise
            # np.equal(...).any() accepted files sharing a single name and
            # raised a broadcasting error for lists of different lengths.
            if list(names) != reference:
                self._fail("All files must have the same variable names")

        self.variable_names = var_names[0]

    def set(self, args, **kwargs):
        """Load the input files and configure the plot requested by ``args.func``.

        Parameters
        ----------
        args:
            Parsed arguments. Always read: ``filename`` (iterable of paths),
            ``var``, ``mval`` and ``func``. Read unless ``func`` is
            ``self.dump``: ``out`` and ``oname``.
        **kwargs:
            ``if_manual_plot`` (default False) shows each figure with
            ``plt.show()``; ``if_save_result`` (default True) saves figures.
        """
        self.__check_arg_dim_equiv(args)
        self.filename = args.filename

        self.data = []
        for fname in self.filename:
            p = ProcData(fname)
            p.get_variable_names()
            self.data.append(p)
        self.__get_variable_names()

        if args.func == self.dump:
            return

        self.out = args.out
        self.oname = args.oname
        self.var = args.var
        self.mval = args.mval
        self.function = args.func
        self.if_manual_plot = kwargs.get('if_manual_plot', False)
        self.if_save_result = kwargs.get('if_save_result', True)

        # Number of leading variable names each plot kind uses as axes; the
        # remaining names are the default variables to plot.
        ndim = None
        for func, dims in ((self.plot, 1), (self.ani_plot, 1),
                           (self.plot_contour, 2), (self.avg_plot, 3)):
            if args.func == func:
                ndim = dims
                break

        if ndim is not None:
            if self.var is None and len(self.variable_names) != 0:
                self.var = list(self.variable_names[ndim:])
            for p in self.data:
                p.process_file(ndim, self.var, self.mval)

        # var can still be None only if there are no variable names or func is
        # not a known plot kind; there is nothing to draw in either case.
        self.fig_count = 0 if self.var is None else len(self.var)

    def __plot_plt(self):
        """Plot each selected variable against the first variable of the first file."""
        self._ensure_dir(self.out)
        x_name = self.variable_names[0]
        columns = self.data[0].data

        fig = plt.figure()
        for y_name in self.var:
            plt.plot(columns[x_name], columns[y_name], linewidth=4)
        plt.legend(self.var)
        plt.xlabel(x_name, fontsize=10, fontweight='bold')

        self._show_or_close(fig)
        if self.if_save_result:
            fig.savefig(self._out_path(self.oname[0]))

    def __ani_plot(self):
        """Save a GIF with one frame per input file, in natural filename order.

        Does nothing unless ``if_save_result`` is set. The GIF is written to
        ``oname[0]`` inside the output directory ``out``.
        """
        if not self.if_save_result:
            return

        self._ensure_dir(self.out)
        x_name = self.variable_names[0]

        # Sort (filename, data) pairs together so every frame shows the data
        # of the file it corresponds to; sorting only the names mismatched them.
        frames = natsort.natsorted(zip(self.filename, self.data),
                                   key=lambda pair: pair[0])

        images = []
        # A private temporary directory is always cleaned up, even on error,
        # and cannot clash with (or delete) an existing user directory.
        with tempfile.TemporaryDirectory(prefix="plotter_lib_pngs_") as tmp_dir:
            for index, (_, proc) in enumerate(frames):
                fig = plt.figure()
                for y_name in self.var:
                    plt.plot(proc.data[x_name], proc.data[y_name], linewidth=4)
                plt.legend(self.var)
                plt.xlabel(x_name, fontsize=10, fontweight='bold')

                # Index-based names never collide; basename.split('.')[0] did
                # for inputs such as "run.1.dat" and "run.2.dat".
                png_path = os.path.join(tmp_dir, "frame_{:06d}.png".format(index))
                plt.close(fig)
                fig.savefig(png_path)
                images.append(imageio.v2.imread(png_path))

        imageio.mimsave(self._out_path(self.oname[0]), images, fps = 5)

    def __plot_contour(self):
        """Draw one filled-contour figure per selected variable of the first file.

        The first two variable names provide the X and Y axes. Figures are
        named from ``oname`` or, when it is None, ``<variable>.png``.
        """
        self._ensure_dir(self.out)
        x_name = self.variable_names[0]
        y_name = self.variable_names[1]
        columns = self.data[0].data

        if self.oname is None:
            fig_names, fig_end = self.var, ".png"
        else:
            fig_names, fig_end = self.oname, ""

        for i in range(self.fig_count):
            var_name = self.var[i]
            fig, ax = plt.subplots(1, 1)
            cp = ax.contourf(columns[x_name], columns[y_name], columns[var_name])
            fig.colorbar(cp)  # Add a colorbar to the plot
            ax.set_title(var_name)
            ax.set_xlabel(x_name)
            ax.set_ylabel(y_name)

            self._show_or_close(fig)
            if self.if_save_result:
                fig.savefig(self._out_path(fig_names[i] + fig_end))

    def __avg(self, data, var_name):
        """Average ``data[var_name]`` over the first two grid axes.

        The first ``cx * cy * cz`` values of the flattened field are reshaped
        to ``(cx, cy, cz)`` in Fortran (first-index-fastest) order and averaged
        over axes 0 and 1.

        Returns a float64 array of length ``cz``. Raises IndexError if the
        field holds fewer than ``cx * cy * cz`` values.
        """
        cx, cy, cz = int(data['cx']), int(data['cy']), int(data['cz'])
        count = cx * cy * cz

        flat_data = np.asarray(data[var_name]).ravel()
        if flat_data.size < count:
            raise IndexError(
                "field %r has %d values, expected at least %d"
                % (var_name, flat_data.size, count))

        grid = flat_data[:count].astype(np.float64).reshape((cx, cy, cz), order='F')
        return grid.mean(axis=(0, 1))

    def __avg_plot(self):
        """Plot each selected variable, averaged over the first two grid axes,
        against the third variable of the first file."""
        self._ensure_dir(self.out)
        x_name = self.variable_names[2]
        columns = self.data[0].data
        x = columns[x_name]

        fig = plt.figure()
        for var in self.var:
            plt.plot(x, self.__avg(columns, var), linewidth=4)
        plt.legend(self.var)
        plt.xlabel(x_name, fontsize=10, fontweight='bold')

        self._show_or_close(fig)
        if self.if_save_result:
            fig.savefig(self._out_path(self.oname[0]))

    def __dump(self):
        """Print the variable names on one line, space separated."""
        for variable_name in self.variable_names:
            print(variable_name, end=' ')
        print('\n')

    def dump(self):
        """Print the variable names of the loaded files."""
        self.__dump()

    def plot(self):
        """Line plot of the selected variables (first file only)."""
        self.__plot_plt()

    def ani_plot(self):
        """Animated GIF with one frame per input file."""
        self.__ani_plot()

    def avg_plot(self):
        """Line plot of grid-averaged profiles of the selected variables."""
        self.__avg_plot()

    def plot_contour(self):
        """Filled-contour plot for each selected variable."""
        self.__plot_contour()