import h5py
import mechnet as mn
import matplotlib.pyplot as plt
import numpy as np
import os
import math

# Set the base font size and the axis-label, tick-label and legend font sizes
# to 15 for every figure produced by this script.
_FONT_SIZE_KEYS = (
    'font.size',
    'axes.labelsize',
    'xtick.labelsize',
    'ytick.labelsize',
    'legend.fontsize',
)
plt.rcParams.update({key: 15 for key in _FONT_SIZE_KEYS})

def x_position(node, nx, ny):
    """Return the x index of flat node index ``node`` (i.e. ``node % nx``).

    ``ny`` is not used here; the three ``*_position`` helpers share one signature.
    """
    _, column = divmod(node, nx)
    return column

def y_position(node, nx, ny):
    """Return the y index of flat node index ``node`` (i.e. ``(node // nx) % ny``)."""
    row, _ = divmod(node, nx)
    return row % ny

def z_position(node,nx, ny):
    """Return the z index of flat node index ``node`` (i.e. ``node // (nx * ny)``)."""
    plane_size = nx * ny
    return node // plane_size

def get_broken_edges(filepath : str):
    """Read the logged broken edges of every simulation group in one HDF5 file.

    ``filepath`` is appended verbatim to the current working directory, so it
    is expected to start with a path separator (e.g. ``"/dir/file.h5"``).

    Returns a list with one entry per group path returned by
    ``mn.datman.bundle_all_sub_groups`` for the search string
    "simulationdata" (presumably the groups whose name contains it; see
    mechnet). Each entry is that group's ``logged_data/broken_edges`` dataset
    converted to nested Python lists.
    """
    # Plain concatenation, not os.path.join: joining with an absolute second
    # argument would discard the working directory.
    filepath_complete = os.getcwd() + filepath
    searchstringlist = ["simulationdata"]
    grouppaths = mn.datman.bundle_all_sub_groups(filepath_complete, searchstringlist)
    if not grouppaths:
        return []
    # Open the file once and read-only, rather than re-opening it for every
    # group. h5py < 3.0 defaults to append mode, which can modify the file.
    with h5py.File(filepath_complete, "r") as file:
        broken_edges = [
            np.array(file[grouppath + "/logged_data/broken_edges"]).tolist()
            for grouppath in grouppaths
        ]
    return broken_edges

def get_vertical_edges_coordinates(edge_list, nx, ny):
    """Convert node-pair edges to coordinates, keeping only edges that change z.

    ``edge_list`` holds one list per run of ``(node_1, node_2)`` pairs of flat
    node indices. Returns, per run, a list of ``[x1, y1, z1, x2, y2, z2]``
    entries for the edges whose two endpoints have different z indices.
    """
    def _node_coords(node):
        # (x, y, z) of a flat node index on an nx-by-ny layered lattice.
        return [x_position(node, nx, ny), y_position(node, nx, ny), z_position(node, nx, ny)]

    def _edge_coords(edge):
        start, end = edge
        return _node_coords(start) + _node_coords(end)

    return [
        [coords for coords in map(_edge_coords, run) if coords[2] != coords[5]]
        for run in edge_list
    ]

def get_data_dict(filepaths_dict : dict[str, "str | list[str]"]) -> dict[str, list]:
    """Collect the broken-edge data for each realization.

    Each value of ``filepaths_dict`` is either a single file path or a list of
    file paths, passed on to ``get_broken_edges``. For a list, the per-group
    results of all files are concatenated into one flat list of runs. The
    result therefore has the same shape as a single-file call.

    Returns a dict mapping each realization key to its list of runs.
    """
    offset_data_dict = {}
    for realization, filepath in filepaths_dict.items():
        if isinstance(filepath, list):
            broken_edges = []
            for filepath_i in filepath:
                # extend, not append(*...): list.append takes exactly one
                # argument, so unpacking raised TypeError for any file that
                # yields zero groups or more than one group.
                broken_edges.extend(get_broken_edges(filepath_i))
        else:
            broken_edges = get_broken_edges(filepath)
        offset_data_dict[realization] = broken_edges
    return offset_data_dict

def get_average_heights_dict(run_list : list[list[list[int]]]):
    """Average, over runs, the number of edges found at each height.

    ``run_list`` holds one list per run of edge coordinates
    ``[x1, y1, z1, x2, y2, z2]``; an edge's height is ``min(z1, z2)``.

    Returns ``{height: {"avg": mean, "yerror": std}}``. Both statistics are
    taken over all runs, with runs that have no edge at a height counted as
    0. ``std`` is the population standard deviation (``np.std``, ddof=0).
    Heights at which no run has an edge do not appear.
    """
    no_of_runs = len(run_list)
    heights_dict = {}
    for run, edge_coord_list in enumerate(run_list):
        for edge_coord in edge_coord_list:
            z = min(edge_coord[2], edge_coord[5])
            # First time this height is seen: start every run's count at 0 so
            # runs without an edge here still contribute to the mean.
            if z not in heights_dict:
                heights_dict[z] = [0] * no_of_runs
            heights_dict[z][run] += 1
    return {
        height: {"avg": np.mean(counts), "yerror": np.std(counts)}
        for height, counts in heights_dict.items()
    }
        


def flatten_dict(dict_to_flatten : dict[int, list[list[int]]]) -> dict[int, list[int]]:
    """Return a new dict whose values are each value's inner lists concatenated.

    Keys and their order are preserved; the input dict is not modified.
    """
    return {
        key: [item for inner_list in val_list for item in inner_list]
        for key, val_list in dict_to_flatten.items()
    }

def set_matplotlib_dims(data_list, width = 3):
    """Compute a subplot grid with one panel per entry of ``data_list``.

    Returns ``(panel_count, width, depth)``. ``panel_count`` is
    ``len(data_list)``, and ``depth`` is the number of rows needed to hold
    that many panels at ``width`` panels per row.
    """
    panel_count = len(data_list)
    # Divide by the requested width; the previous hard-coded 3 gave too few
    # rows whenever width < 3.
    depth = math.ceil(panel_count / width)
    return panel_count, width, depth

def get_index(i, graph_num, width):
    """Map panel number ``i`` to an index into a ``plt.subplots`` axes array.

    If all ``graph_num`` panels fit in one row (``graph_num <= width``), ``i``
    is returned unchanged for a 1-D axes array. Otherwise ``(row, column)`` is
    returned for a 2-D array.

    NOTE(review): with ``width == 1`` and several panels, ``plt.subplots``
    returns a 1-D array, yet a tuple is returned here. Confirm that width 1 is
    never used.
    """
    if graph_num > width:
        return divmod(i, width)
    return i

def plot_distribution_of_height(edges_coordinates_dict : dict[str, list[list[list[int]]]]):
    """Show, per realization, a histogram of the heights of the given edges.

    Each value of ``edges_coordinates_dict`` is a list of runs. Each run is a
    list of edges ``[x1, y1, z1, x2, y2, z2]``. All runs of a realization are
    pooled, and an edge's height is ``min(z1, z2)``. One subplot is drawn per
    realization.
    """
    graph_num, width, depth = set_matplotlib_dims(edges_coordinates_dict)
    fig, axs = plt.subplots(depth, width)
    edges_coordinates_dict_flat = flatten_dict(edges_coordinates_dict)
    for i, (realization, edges_coordinates) in enumerate(edges_coordinates_dict_flat.items()):
        ax = axs[get_index(i, graph_num, width)]
        ax.set_title(f"{realization}")
        ax.set_xlabel("z position")
        ax.set_ylabel("Broken edges")
        z_list = [min(edge[2], edge[5]) for edge in edges_coordinates]
        if not z_list:
            # min()/max() of an empty list would raise; leave the panel empty.
            continue
        # One bin per integer height, with bin edges at half-integers so each
        # bar is centred on its height.
        bins = np.arange(min(z_list), max(z_list) + 2) - 0.5
        ax.hist(z_list, bins=bins)
    # Lay out after titles and labels exist so the spacing accounts for them.
    fig.tight_layout(pad=1.0)
    plt.show()

def plot_average_per_height(average_heights_dict):
    """Plot, per realization, the mean broken-edge count at each height.

    ``average_heights_dict`` maps realization -> ``{height: {"avg": ...,
    "yerror": ...}}``, i.e. one ``get_average_heights_dict`` result per
    realization. ``yerror`` is drawn as the error bar. The figure is saved to
    "avg_broken_edges_at_height.png" in the current working directory and
    then shown.
    """
    graph_num, width, depth = set_matplotlib_dims(average_heights_dict)
    fig, axs = plt.subplots(depth, width)
    for i, (realization, per_height) in enumerate(average_heights_dict.items()):
        ax = axs[get_index(i, graph_num, width)]
        heights = sorted(per_height)
        if heights:
            # Draw the whole panel with one errorbar call instead of one per point.
            ax.errorbar(
                [int(height) for height in heights],
                [per_height[height]['avg'] for height in heights],
                yerr=[per_height[height]['yerror'] for height in heights],
                fmt='o', color='black',
            )
        # Titles and labels do not depend on the points, so set them once.
        ax.set_xlabel("Height")
        ax.set_ylabel("Average number of broken edges")
        ax.set_title(f"{realization}")
    fig.tight_layout(pad=1.0)
    # Save before show(): once the window is closed the figure may be torn
    # down, and saving afterwards can write an empty image.
    fig.savefig("avg_broken_edges_at_height.png")
    plt.show()

def get_current_toughness_dict(filepaths_dict : dict):
    """Read peak-current and toughness statistics for each offset.

    Each value of ``filepaths_dict`` is a path or a list of paths shaped like
    ``"/<top_dir>/..."`` (as in ``__main__``). Only the first path's
    ``<top_dir>`` is used. The file ``<top_dir>/strength`` is opened relative
    to the current working directory. Its first non-blank line holds
    ``ipavg ipsig`` and its second holds ``tavg tsig``, separated by
    whitespace.

    Returns ``{offset: {"ipavg", "ipsig", "tavg", "tsig"}}`` with float values.
    """
    output_dict = {}
    for offset, filepath in filepaths_dict.items():
        # Accept a single path as well as a list of paths, like get_data_dict.
        # For a plain string, filepath[0] would be "/" and resolve to
        # "/strength" at the filesystem root.
        first_path = filepath if isinstance(filepath, str) else filepath[0]
        # split('/') on "/<top_dir>/..." gives ['', '<top_dir>', ...].
        strength_path = first_path.split('/')[1] + "/strength"
        with open(strength_path, 'r') as file:
            # split() rather than split(" ") so repeated spaces are tolerated.
            rows = [line.split() for line in file if line.strip()]
        output_dict[offset] = {
            "ipavg": float(rows[0][0]),
            "ipsig": float(rows[0][1]),
            "tavg": float(rows[1][0]),
            "tsig": float(rows[1][1]),
        }
    return output_dict

def plot_current_and_toughness(filepaths_dict : dict):
    """Plot peak current (left) and toughness (right) against offset.

    The values come from ``get_current_toughness_dict(filepaths_dict)``; the
    sigma values are drawn as error bars. The figure is saved to
    "max_current_toughness_at_offset.png" in the current working directory
    and then shown.
    """
    data_dict = get_current_toughness_dict(filepaths_dict)
    fig, axs = plt.subplots(1, 2)
    for offset, subdict in data_dict.items():
        axs[0].errorbar(offset, subdict["ipavg"], yerr=subdict["ipsig"], fmt='o', color='black')
        axs[1].errorbar(offset, subdict["tavg"], yerr=subdict["tsig"], fmt='o', color='black')
    # Labels do not depend on the data, so set them once outside the loop.
    axs[0].set_xlabel("Offset")
    axs[1].set_xlabel("Offset")
    axs[0].set_ylabel("Peak current")
    axs[1].set_ylabel("Toughness")
    # Save before show(): once the window is closed the figure may be torn
    # down, and saving afterwards can write an empty image.
    fig.savefig("max_current_toughness_at_offset.png")
    plt.show()
        

def plot_edges(edges_coordinates : list[list[int]]):
    """Draw each edge as a 3-D line segment.

    ``edges_coordinates`` is a list of edges ``[x1, y1, z1, x2, y2, z2]``,
    i.e. one run of ``get_vertical_edges_coordinates`` output.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for edge in edges_coordinates:
        x = [edge[0], edge[3]]
        y = [edge[1], edge[4]]
        z = [edge[2], edge[5]]
        ax.plot(x, y, z, linewidth=1)
    ax.set_xlabel("x position")
    ax.set_ylabel("y position")
    ax.set_zlabel("z position")
    plt.show()



if __name__ == "__main__":
    # Each realization maps to 50 HDF5 result files (run numbers 1000-1049).
    # The paths are appended to the working directory set below.
    filepaths_dict = {
        "hierarchical": [
            "/5_precracked/a_80/RTS_FDB_PARSOL_SIRC-FCVC_SOG-IVCG_HFBA_CPC-NIC_-_{number}.h5".format(number=i)
            for i in range(1000, 1050)
        ],
        "density reference": [
            "/6_precrack_density_offset/a_80/RTS_FDB_PARSOL_SIRC-FCVC_SOG-IVCG_HFBA_CPC-NIC_-_{number}.h5".format(number=i)
            for i in range(1000, 1050)
        ],
    }
    os.chdir("/FASTTEMP/p7/lpyka/hierarchical_interface")
    data_dict = get_data_dict(filepaths_dict)

    # Lattice dimensions used to convert flat node indices into (x, y, z).
    nx = 128
    ny = 128
    broken_edges_coordinates_dict = {
        realization: get_vertical_edges_coordinates(broken_edges, nx, ny)
        for realization, broken_edges in data_dict.items()
    }
    average_heights_dict = {
        realization: get_average_heights_dict(edge_coordinates)
        for realization, edge_coordinates in broken_edges_coordinates_dict.items()
    }

    # Optional plots; uncomment to enable.
    # plot_current_and_toughness(filepaths_dict)
    plot_average_per_height(average_heights_dict)
    # plot_distribution_of_height(broken_edges_coordinates_dict)
    # NOTE(review): `broken_edges_coordinates` is not defined in this script;
    # plot_edges expects a single run's edge list.
    # plot_edges(broken_edges_coordinates)
    print("cutoff prev")