import json
from random import shuffle
import os
import sys

import numpy as np
#from scipy.stats import truncnorm, beta, uniform

# Category codes for surface / road / tree (presumably per-cell labels in a
# land-cover grid — confirm with callers).
SURFACE, ROAD, TREE = 0, -1, -2

# Presumably RGB display colours for the same three categories.
ROAD_COLOR, SURFACE_COLOR, TREE_COLOR = (64, 61, 57), (96, 108, 56), (58, 90, 64)

def uniform_(low, upper):
    """Return a frozen continuous uniform distribution on ``[low, upper]``.

    ``scipy.stats.uniform`` is parameterised by ``loc`` and ``scale`` (its
    support is ``[loc, loc + scale]``), so the upper bound is converted into a
    width here.
    """
    # Imported locally: the module-level scipy import is commented out, which
    # left ``uniform`` undefined and made every call raise NameError.
    from scipy.stats import uniform

    return uniform(low, upper - low)

class Splitter:
    """Recursively cut an axis-aligned rectangle into blocks separated by gaps.

    A region is cut along an axis while its extent on that axis exceeds
    ``split_distance`` (room for two minimum-width blocks, one gap and one
    spare cell). Cut positions come from ``rng.integers``. Between sibling
    blocks a strip ``min_distance`` cells wide is left out of every block.

    Args:
        min_distance: width of the gap left between sibling blocks.
        min_block_width: minimum extent kept on each side of a cut.
        rng: random generator exposing ``integers(low, high)``, e.g. the
            ``numpy.random.Generator`` the LCZ class passes in.
    """

    def __init__(self, min_distance, min_block_width, rng):
        self.min_distance = min_distance
        self.min_block_width = min_block_width
        self.rng = rng

        # Smallest extent that still fits block + gap + block (+1 spare cell).
        self.split_distance = min_block_width * 2 + min_distance + 1

    def gen(self, x1, y1, x2, y2):
        """Split the rectangle ``(x1, y1)-(x2, y2)`` and return its leaf blocks.

        The result is a list of ``(top_left_x, top_left_y, bottom_right_x,
        bottom_right_y)`` tuples. The same list is also kept on
        ``self.blocks``.
        """
        self.blocks = list()
        self.gen_blocks(x1, y1, x2, y2)
        return self.blocks

    def gen_blocks(self, top_left_x, top_left_y, bottom_right_x, bottom_right_y):
        """Recursively split one rectangle, appending leaves to ``self.blocks``.

        Requires ``self.blocks`` to exist (``gen`` sets it up). A cut along x
        and/or y is drawn first, then the resulting pieces are processed in
        order: top-left, top-right, bottom-left, bottom-right.
        """
        gap = self.min_distance
        margin = self.min_block_width
        limit = self.split_distance

        span_x = bottom_right_x - top_left_x
        span_y = bottom_right_y - top_left_y
        cut_along_x = span_x > limit
        cut_along_y = span_y > limit

        def draw(lo, hi):
            # Keeps >= margin + 1 cells before the cut and after the gap.
            return self.rng.integers(lo + margin + 1, hi - margin - gap)

        if cut_along_x and cut_along_y:
            cx = draw(top_left_x, bottom_right_x)
            cy = draw(top_left_y, bottom_right_y)
            pieces = (
                (top_left_x, top_left_y, cx, cy),
                (cx + gap, top_left_y, bottom_right_x, cy),
                (top_left_x, cy + gap, cx, bottom_right_y),
                (cx + gap, cy + gap, bottom_right_x, bottom_right_y),
            )
        elif cut_along_x and span_y >= margin:
            cx = draw(top_left_x, bottom_right_x)
            pieces = (
                (top_left_x, top_left_y, cx, bottom_right_y),
                (cx + gap, top_left_y, bottom_right_x, bottom_right_y),
            )
        elif cut_along_y and span_x >= margin:
            cy = draw(top_left_y, bottom_right_y)
            pieces = (
                (top_left_x, top_left_y, bottom_right_x, cy),
                (top_left_x, cy + gap, bottom_right_x, bottom_right_y),
            )
        else:
            # Too small to cut on any usable axis: this rectangle is a leaf.
            self.blocks.append((top_left_x, top_left_y, bottom_right_x, bottom_right_y))
            return

        for piece in pieces:
            self.gen_blocks(*piece)

class LCZ:
    """Generator of synthetic LCZ (local climate zone) building layouts.

    The constructor reads a JSON config, derives street and lot sizes from it
    and lays out rectangular building footprints (``gen_buildings_v2``). The
    layout can then be rasterised with ``to_height_map`` or written as text
    grids with the ``to_model_input_*`` methods.

    Every bounding box here is ``(x1, y1, x2, y2)`` with half-open ranges, i.e.
    it covers ``lcz[y1:y2, x1:x2]``.
    """

    def __init__(self, config_path='config.json', output_folder='.', seed=None):
        """Load the config, derive layout parameters and generate buildings.

        Args:
            config_path: path to the JSON config file.
            output_folder: directory the ``to_model_input_*`` methods write
                into; created (with parents) if it does not exist.
            seed: seed for ``np.random.default_rng``. Splitting, shuffling of
                candidate lots and building heights all draw from this one
                generator, so a fixed seed reproduces the layout.
        """
        self.rng = np.random.default_rng(seed)
        self.config = self.load_config(config_path)
        self.output_folder = output_folder

        # makedirs also creates missing parents and tolerates an existing
        # directory, without a separate exists() check.
        os.makedirs(output_folder, exist_ok=True)

        self.width = self.config['width']
        self.height = self.config['height']

        self.total_area = self.config['width'] * self.config['height']
        self.config_building_area = self.total_area * self.config['building_surface_fraction']
        # aspect ratio = mean building height / street width. All streets share
        # one width, so this is also the width of every street.
        self.mean_street_width = self.config['mean_height_of_roughness_elements'] / self.config['aspect_ratio']

        self.config_tree_area = self.total_area * self.config['pervious_surface_fraction']

        self.min_building_height = self.config['min_building_height']
        self.max_building_height = self.config['max_building_height']

        # All (shuffled) candidate lots from the last gen_buildings_v2 call.
        self.possible_buildings = []

        self.n_buildings = 0
        self.real_building_area = 0
        self.buildings = []

        self.n_trees = 0
        self.real_tree_area = 0
        self.trees = []

        self.building_heights = np.array([])
        self.tree_heights = np.array([])

        # Gap between neighbouring lots, in cells (at least one cell).
        self.min_building_distance = max(int(self.mean_street_width), 1)
        # Rough estimate of how many map-spanning streets fit in the unbuilt area.
        max_n_street = (self.total_area - self.config_building_area) / (self.mean_street_width * min(self.width, self.height)) - 1

        self.min_building_width = 2 * (int(self.config_building_area / (min(self.width, self.height) * max_n_street)))
        self.split_width = self.min_building_width * 2 + self.min_building_distance + 1

        self.gen_buildings_v2()

    def check_params(self):
        """Print statistics of the generated layout next to the configured targets."""
        if len(self.building_heights) == 0:
            # np.max/np.min raise on empty arrays; report instead of crashing.
            print("no buildings generated; nothing to compare")
            return

        real_mean_height = np.mean(self.building_heights)
        real_std_height = np.std(self.building_heights)
        real_max_building_height = np.max(self.building_heights)
        real_min_building_height = np.min(self.building_heights)

        real = f"""
        generated LCZ params:
        std of height of roughness elements: {real_std_height}
        mean height of roughness elements: {real_mean_height}
        max height of roughness elements: {real_max_building_height}
        min height of roughness elements: {real_min_building_height}
        aspect ratio: {real_mean_height / self.mean_street_width}
        building surface fraction: {self.real_building_area / self.total_area}
        """

        desired = f"""
        desired LCZ params:
        std of height of roughness elements: {self.config['standard_deviation_of_roughness_elements']}
        mean height of roughness elements: {self.config['mean_height_of_roughness_elements']}
        max height of roughness elements: {self.config['max_building_height']}
        min height of roughness elements: {self.config['min_building_height']}
        aspect ratio: {self.config['aspect_ratio']}
        building surface fraction: {self.config['building_surface_fraction']}
        """

        print(real)
        print(desired)

    def gen_building_heights(self):
        """Draw one float height per building, uniform on [min, max + 1).

        Heights are truncated to integers when written into the int64 map from
        ``to_height_map``. That truncation gives integer heights in [min, max].
        """
        self.building_heights = self.rng.uniform(self.min_building_height, self.max_building_height + 1, self.n_buildings)

    def gen_buildings(self):
        """Single-level layout: split the map (minus a border) directly into lots."""
        margin = self.min_building_distance
        splitter = Splitter(self.min_building_distance, self.min_building_width, self.rng)
        candidates = splitter.gen(margin, margin, self.width - margin, self.height - margin)

        self._take_buildings(self._shuffled(candidates))
        self.gen_building_heights()

    def gen_buildings_v2(self):
        """Two-level layout: quarters separated by main roads, then lots separated by roads.

        Main roads are 1.5x the minimum gap; the minor-road width is chosen so
        that ``main_road + road == 2 * min_building_distance``.
        """
        self.main_road = int((3/2) * self.min_building_distance)
        self.road = 2 * self.min_building_distance - self.main_road

        margin = self.min_building_distance
        splitter_big = Splitter(self.main_road, self.width // 5, self.rng)
        blocks = splitter_big.gen(margin, margin, self.width - margin, self.height - margin)

        splitter_small = Splitter(self.road, self.min_building_width, self.rng)

        possible_buildings = []
        for block in blocks:
            possible_buildings += splitter_small.gen(*block)

        possible_buildings = self._shuffled(possible_buildings)

        self.possible_buildings = possible_buildings[:]

        self._take_buildings(possible_buildings)

        # TODO: tree placement from the leftover lots (up to config_tree_area)
        # is not implemented, so n_trees stays 0 here.

        self.gen_building_heights()

    def _shuffled(self, items):
        """Return a new list with *items* in random order, drawn from ``self.rng``."""
        return [items[i] for i in self.rng.permutation(len(items))]

    def _take_buildings(self, candidates):
        """Pop lots from the end of *candidates* into ``self.buildings`` until the target area is met.

        Mutates *candidates*. If the candidates run out before the target built
        area is reached, prints a notice and stops.
        """
        while self.real_building_area < self.config_building_area:
            if not candidates:
                print("end of area for buildings")
                break
            building = candidates.pop()
            self.buildings.append(building)

            self.real_building_area += self.area(building)
            self.n_buildings += 1

    def area(self, bbox):
        """Return the area of ``bbox = (x1, y1, x2, y2)``."""
        return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])

    def add_building(self, lcz, bbox, height=2):
        """Fill the rectangle *bbox* of the 2-D grid *lcz* with *height* (in place)."""
        lcz[bbox[1]:bbox[3], bbox[0]:bbox[2]] = height

    def add_tree(self, lcz, bbox, height=2):
        """Fill the disc inscribed in *bbox* (centred, radius = half the shorter side) with *height*."""
        dx, dy = (bbox[2] - bbox[0]) / 2, (bbox[3] - bbox[1]) / 2
        r = min([dx, dy])
        cx = bbox[0] + dx
        cy = bbox[1] + dy
        # Open grids broadcast to the same mask without allocating two full
        # (height, width) index arrays per tree.
        yy, xx = np.ogrid[:self.height, :self.width]
        circle = (xx - cx) ** 2 + (yy - cy) ** 2 <= r ** 2
        lcz[circle] = height

    def load_config(self, config_path):
        """Parse the JSON config at *config_path* and return it (file is closed afterwards)."""
        with open(config_path, 'r', encoding='utf-8') as file:
            return json.load(file)

    def put_buildings(self, lcz):
        """Draw every generated building into *lcz* at its generated height."""
        for i in range(self.n_buildings):
            self.add_building(lcz, self.buildings[i], self.building_heights[i])

    def put_trees(self, lcz):
        """Draw every tree into *lcz* with a fixed height of 16."""
        for i in range(self.n_trees):
            self.add_tree(lcz, self.trees[i], height=16)

    def _write_model_input(self, lcz, filename):
        """Write *lcz* to output_folder/filename: a "<height> <width>" header, then one space-separated row per line."""
        with open(os.path.join(self.output_folder, filename), "w", encoding='utf-8') as file:
            file.write(f"{self.height} {self.width}\n")
            file.write("\n".join(" ".join(map(str, row)) for row in lcz))

    def to_model_input_building(self, filename='lcz.txt'):
        """Write the building height map as a model-input text grid."""
        self._write_model_input(self.to_height_map(), filename)

    def to_model_input_tree(self, filename='lcz.txt'):
        """Write the tree-only height map as a model-input text grid."""
        # Keywords are required here: positionally, False would land in `dtype`.
        self._write_model_input(self.to_height_map(building=False, tree=True), filename)

    def to_model_input_road(self, filename='lcz.txt'):
        """Write a 0/1 mask: 0 on every candidate lot, 1 elsewhere (roads and border)."""
        lcz = np.ones((self.height, self.width), dtype=int)
        for bbox in self.possible_buildings:
            self.add_building(lcz, bbox, 0)
        self._write_model_input(lcz, filename)

    def to_height_map(self, dtype=np.int64, building=True, tree=False):
        """Rasterise the layout into a new ``(height, width)`` array.

        Args:
            dtype: element type of the map (float heights are truncated for int types).
            building: draw buildings at their generated heights.
            tree: draw trees (drawn after buildings, so they overwrite them).
        """
        lcz = np.zeros((self.height, self.width), dtype=dtype)
        if building:
            self.put_buildings(lcz)
        if tree:
            self.put_trees(lcz)
        return lcz

    def model_input_to_height_map(self, filename='map.txt'):
        """Read a whitespace-separated integer grid back into a 2-D array.

        Blank lines are ignored. A leading "<rows> <cols>" header, as written
        by the ``to_model_input_*`` methods, is dropped when it matches the
        grid that follows.
        """
        with open(filename, encoding='utf-8') as file:
            rows = [list(map(int, line.split())) for line in file if line.strip()]
        if len(rows) > 1 and rows[0] == [len(rows) - 1, len(rows[1])]:
            rows = rows[1:]
        return np.array(rows)

if __name__ == '__main__':
    # CLI mode: `python <this file> <lcz_type>` builds configs/<lcz_type>.json
    # and writes its building height map to map.txt.
    # (sys.argc does not exist in Python; the argument count is len(sys.argv).)
    if len(sys.argv) > 1:
        t = sys.argv[1]
        lcz = LCZ(config_path=f'configs/{t}.json')
        lcz.to_model_input_building('map.txt')
        sys.exit(0)

    # The demo below plots with matplotlib. It is imported only here, so the CLI
    # path and plain imports of this module do not require it.
    import matplotlib.pyplot as plt

    # Presumably the config names available under configs/ (reference only;
    # the demo uses a single type).
    lcz_types = ['compact_high_rise', 'compact_mid_rise', 'compact_low_rise', 'open_high_rise', 'open_mid_rise', 'open_low_rise', 'lightweight_low_rise', 'large_low_rise', 'sparsley_build', 'heavy_industry']

    fig, axs = plt.subplots(2, 2, figsize=(10, 6))

    lcz_type = "open_mid_rise"

    out_folder = "."

    # Generate four independent realisations of the same LCZ type.
    for i, ax in enumerate(axs.flat):
        lcz = LCZ(config_path=f'configs/{lcz_type}.json', output_folder=out_folder)

        lcz.to_model_input_building(f'lcz_{i + 1}_building.txt')
        lcz.to_model_input_tree(f'lcz_{i + 1}_tree.txt')
        lcz.to_model_input_road(f'lcz_{i + 1}_road.txt')
        lcz.check_params()

        height_map = lcz.to_height_map()
        im = ax.imshow(height_map, cmap='viridis', origin='lower')
        ax.set_title(f"{lcz_type} {i}")
        fig.colorbar(im, ax=ax, shrink=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(out_folder, 'height_maps_mid_rise.png'))
    plt.show()