import os
from typing import List
from enum import Enum

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import f90nml

# Plotting script for soil-carbon model results.
#
# Prerequisites:
#     python 3.9
#     numpy, matplotlib, pandas, f90nml (pip install -r requirements.txt)
#
# Usage:
#     python3 plot.py
#
# The carbon model type and the station name are read from the
# ``config_namelist`` group of ``ui1_config.nml``; the model result paths,
# the observation file path and the pool set below are derived from them.

nml = f90nml.read('ui1_config.nml')
carbon_model_type = nml['config_namelist']['carbon_model_type']
station = nml['config_namelist']['station_name']

# Stations with four model runs; every other station has three.
_FOUR_RUN_STATIONS = ('DAO3', 'DAO4')
_n_runs = 4 if station in _FOUR_RUN_STATIONS else 3
model_data_paths = [
    f'results/{carbon_model_type}/{station}_{run}.txt'
    for run in range(1, _n_runs + 1)
]

# NOTE(review): a disabled branch here switched to a single shared model
# file (model_data_paths = ['results/common.txt']) under some "environmental"
# condition; its exact trigger is unclear -- confirm before re-enabling.
obs_data_path = f'data/obs_data_{station}.csv'

# Pool column names, in the order they appear after the date column of each
# model result file, mapped to the (Russian) labels used in plot titles.
_POOLS_BY_MODEL = {
    'rothc': {'CDPM': 'Разлагаемый растительный материал',
              'CRPM': 'Устойчивый растительный материал',
              'CBIO': 'Пул микробной биомассы',
              'CHUM': 'Долгоживущий гумусовый пул',
              'CIOM': 'Пул инертного органического вещества'},
    'socs': {'csoil1': 'Органический углерод в почве',
             'csoil2': 'Минерализованный углерод в почве'},
    'other': {'C1': 'Pool 1',
              'C2': 'Pool 2'},
}
# Fallback pool set for any model type not listed above.
_DEFAULT_POOLS = {'Csoil': 'Почва',
                  'Csoilb': 'Почва типа b'}
pools = _POOLS_BY_MODEL.get(carbon_model_type, _DEFAULT_POOLS)


def get_experiment_name(path):
    """Split a model result path into its experiment name and run id.

    The file stem is expected to look like ``<experiment>_<run>``. Only the
    last underscore separates the two parts, so experiment names may contain
    underscores themselves (``'results/my_site_2.txt'`` -> ``('my_site', '2')``).

    Args:
        path: Path to a model result file, e.g. ``'results/rostov_01.txt'``.

    Returns:
        Tuple ``(experiment_name, opt)`` of strings, e.g. ``('rostov', '01')``.

    Raises:
        ValueError: If the file stem contains no underscore.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    experiment_name, separator, opt = stem.rpartition('_')
    if not separator:
        raise ValueError(f'cannot split {path!r} into <experiment>_<run>')
    return experiment_name, opt


def read_data(pools_filename: str, sep: str = ';'):
    """Load one model result file and ensure its per-run output directory exists.

    The file has no header row. Column 0 is parsed as ``date``, and the
    following columns are renamed, in order, to the keys of the module-level
    ``pools`` mapping.

    Args:
        pools_filename: Path to the model result file, named
            ``<experiment>_<run>.<ext>``.
        sep: Field separator passed to ``pandas.read_csv``.

    Returns:
        Tuple ``(df, experiment_name, opt, out_dir)``. ``out_dir`` is the
        directory containing ``pools_filename``. The directory
        ``<out_dir>/<experiment_name>_<opt>`` is created if missing.
    """
    experiment_name, opt = get_experiment_name(pools_filename)
    out_dir = os.path.dirname(pools_filename)

    df = pd.read_csv(pools_filename, sep=sep, header=None)

    # Create the output dir only after the input was read successfully.
    # os.path.join avoids targeting '/<name>_<opt>/' (the filesystem root)
    # when pools_filename has no directory part, and exist_ok=True avoids
    # the check-then-create race of exists() + mkdir().
    os.makedirs(os.path.join(out_dir, f'{experiment_name}_{opt}'), exist_ok=True)

    # Column 0 is the date; the remaining columns follow the order of `pools`.
    col_rename_dict = {0: 'date'}
    col_rename_dict.update({i: pool for i, pool in enumerate(pools, start=1)})
    df = df.rename(columns=col_rename_dict)
    df['date'] = pd.to_datetime(df['date'])
    return df, experiment_name, opt, out_dir


def pools_plots(df: pd.DataFrame, out_path: str) -> None:
    """Write one time-series PNG per carbon pool.

    For every entry of the module-level ``pools`` mapping, ``df[<pool>]`` is
    plotted against ``df['date']`` and saved as ``<out_path>/<pool>.png``.
    The x-axis has yearly ticks and is padded by one year on each side.

    Args:
        df: Model output with a ``date`` column and one column per pool key.
        out_path: Existing directory that receives the PNG files.
    """
    one_year = pd.DateOffset(years=1)
    for pool_key, pool_label in pools.items():
        plt.rcParams['font.size'] = '12'
        fig, ax = plt.subplots(figsize=(15, 6))

        dates = df['date']
        ax.plot(dates, df[pool_key], linewidth=2)

        ax.xaxis.set_major_locator(mdates.YearLocator())
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
        ax.set_xlim([dates.min() - one_year, dates.max() + one_year])

        ax.set_title(f'{pool_label} ({pool_key})')
        ax.set_xlabel('Время, годы')
        ax.set_ylabel('Содержание углерода в почве, кг/м^2')

        ax.grid(True)
        plt.setp(ax.get_xticklabels(), rotation=45)
        fig.tight_layout()
        fig.savefig(f'{out_path}/{pool_key}.png')
        plt.close(fig)


def pools_sum_plots(df: pd.DataFrame, experiment_name, opt, out_path: str) -> None:
    """Save a plot of the summed carbon content of one model run.

    Args:
        df: Model output whose first column is ``date``; every later column
            is added together to form the plotted total.
        experiment_name: Name shown in the plot title.
        opt: Run identifier shown in the title after ``int()`` conversion,
            so it must be integer-like.
        out_path: File path the PNG is written to.
    """
    plt.rcParams['font.size'] = '12'
    fig, ax = plt.subplots(figsize=(15, 6))

    dates = df['date']
    value_columns = df.columns[1:]
    total_carbon = df[value_columns].sum(axis=1)
    ax.plot(dates, total_carbon, linewidth=2)

    pad = pd.DateOffset(years=1)
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax.set_xlim([dates.min() - pad, dates.max() + pad])

    ax.set_title(f'{experiment_name}, {int(opt)}')
    ax.set_xlabel('Время, годы')
    ax.set_ylabel('Содержание углерода в почве, кг/м^2')

    ax.grid(True)
    plt.setp(ax.get_xticklabels(), rotation=45)
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


def pools_sums_plots(file_names: List[str], obs_data: pd.DataFrame, out_path: str) -> None:
    """Overlay the total-carbon curve of every model run with observations.

    Each file in ``file_names`` is loaded with ``read_data`` and its pool sum
    (all non-date columns) is drawn as a line. The observation columns of
    ``obs_data`` (all columns after the first) are drawn as scatter points.
    Line i and observation column i share a colour, and at most
    ``len(file_names)`` observation columns are plotted.

    Args:
        file_names: Model result files, one curve per file.
        obs_data: Observations with a datetime ``date`` column followed by
            one column per observation series.
        out_path: Destination PNG path.
    """
    plt.rcParams['font.size'] = '12'
    fig, ax = plt.subplots(figsize=(15, 6))

    # One colour per model run; observation column i reuses colour i.
    colors = plt.cm.viridis(np.linspace(0, 1, len(file_names)))

    # The x-range is derived from the data itself (no fixed sentinel dates),
    # so runs lying entirely before 1970 or after 2030 are framed correctly.
    date_min = None
    date_max = None
    for file_name, color in zip(file_names, colors):
        df, experiment_name, opt, _ = read_data(file_name)
        pools_sum = df[df.columns[1:]].sum(axis=1)
        ax.plot(df['date'], pools_sum, label=f'Результаты расчетов {experiment_name} {int(opt)}', linewidth=2,
                color=color)
        run_min, run_max = df['date'].min(), df['date'].max()
        date_min = run_min if date_min is None else min(date_min, run_min)
        date_max = run_max if date_max is None else max(date_max, run_max)

    # zip stops after len(file_names) columns, pairing each series with its run colour.
    for i, (col, color) in enumerate(zip(obs_data.columns[1:], colors)):
        ax.scatter(obs_data['date'],
                   obs_data[col],
                   color=color,
                   label=f'Данные наблюдений {i + 1}',
                   s=50)

    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    # NOTE(review): observation dates do not widen the x-range, so points
    # outside the model period are clipped -- confirm this is intended.
    if date_min is not None:
        ax.set_xlim([date_min - pd.DateOffset(years=1), date_max + pd.DateOffset(years=1)])
    ax.set_xlabel('Время, годы')
    ax.set_ylabel('Содержание углерода в почве, кг/м^2')
    ax.grid(True)

    fig.tight_layout(pad=4.0)
    # Legend goes below the axes; bbox_inches='tight' keeps it in the saved image.
    ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.2), fancybox=True, shadow=True, ncol=2)
    plt.setp(ax.get_xticklabels(), rotation=45)
    ax.set_title('Динамика запасов углерода в почве')
    fig.savefig(out_path, format='png', bbox_inches='tight')
    plt.close(fig)


def main():
    """Render per-run pool plots, per-run sum plots and the combined summary plot.

    Observations are read from ``obs_data_path``. Every file in
    ``model_data_paths`` is processed. Outputs are written under each model
    file's directory:

    * ``<dir>/<name>_<opt>/<pool>.png`` per pool,
    * ``<dir>/<name>_<opt>_sum.png`` per run,
    * ``<dir>/<name>_summary_plot.png`` once, using the directory and name
      of the last processed file.
    """
    obs_data = pd.read_csv(obs_data_path)
    obs_data['date'] = pd.to_datetime(obs_data['date'])

    # Without model files there is nothing to plot, and the summary path would
    # otherwise degenerate to '/_summary_plot.png'.
    if not model_data_paths:
        print('no model data paths configured, nothing to plot')
        return

    for file_name in model_data_paths:
        print(f'processing {file_name} started')
        data, experiment_name, opt, out_dir = read_data(file_name)
        run_dir = os.path.join(out_dir, f'{experiment_name}_{opt}')
        pools_plots(data, run_dir)
        print(f'pools plots for {file_name} saved at {run_dir}')
        sum_path = os.path.join(out_dir, f'{experiment_name}_{opt}_sum.png')
        pools_sum_plots(data, experiment_name, opt, sum_path)
        print(f'pools sum plot for {file_name} saved at {sum_path}')

    # out_dir / experiment_name come from the last loop iteration.
    out_path = os.path.join(out_dir, f'{experiment_name}_summary_plot.png')
    pools_sums_plots(model_data_paths, obs_data, out_path)
    print('processing all data is done')


if __name__ == '__main__':
    main()
