from typing import Sequence

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

def x_position(node, nx, ny, nz):
    """Return the x index of flat node index ``node`` on an nx-by-ny-by-nz grid.

    The x index varies fastest, so it is the remainder of ``node`` divided
    by ``nx``. ``ny`` and ``nz`` are accepted only so that all three position
    helpers share one signature; they are not used here.
    """
    _, column = divmod(node, nx)
    return column

def y_position(node, nx, ny, nz):
    """Return the y index of flat node index ``node`` on an nx-by-ny-by-nz grid.

    Integer-dividing by ``nx`` drops the x component. The remainder modulo
    ``ny`` then selects the row inside one z layer. ``nz`` is unused and
    present only for signature symmetry with the other position helpers.
    """
    row_index = node // nx
    return row_index % ny

def z_position(node, nx, ny, nz):
    """Return the z index (layer) of flat node index ``node`` on an nx-by-ny-by-nz grid.

    Each z layer holds ``nx * ny`` nodes, so the layer number is the integer
    quotient of ``node`` by that count. ``nz`` is unused and kept only for
    signature symmetry with the other position helpers.
    """
    layer_size = nx * ny
    return node // layer_size


def get_edges_array(network_dict: dict, nxnynz: Sequence[int]) -> list:
    """Convert the edge keys of a network dict into grid-coordinate segments.

    Parameters
    ----------
    network_dict : dict
        Mapping whose keys have the form ``"<node_1>_<node_2>"``, where each
        node is a flat integer index. Only the keys are read; the values are
        ignored.
    nxnynz : sequence of int
        Grid dimensions ``(nx, ny, nz)`` used to convert flat indices into
        ``(x, y, z)`` positions.

    Returns
    -------
    list
        One ``[x1, y1, z1, x2, y2, z2]`` list per key, in dict iteration order.

    Raises
    ------
    ValueError
        If a key does not consist of exactly two ``_``-separated integers.
    """
    nx, ny, nz = nxnynz

    def coordinates(node):
        # Flat index -> [x, y, z] using the module's position helpers.
        return [
            x_position(node, nx, ny, nz),
            y_position(node, nx, ny, nz),
            z_position(node, nx, ny, nz),
        ]

    edges_array = []
    for key in network_dict:
        parts = key.split('_')
        if len(parts) != 2:
            # Fail with a readable message instead of an opaque unpacking error.
            raise ValueError(
                f"edge key {key!r} must have the form '<node>_<node>'"
            )
        node_1, node_2 = (int(part) for part in parts)
        edges_array.append([*coordinates(node_1), *coordinates(node_2)])
    return edges_array


def plot_edges_mechnet_network(network_dict: dict, nxnynz: Sequence[int], filename="plot",
                               *, xlim=(0, 16), ylim=(-8, 8), zlim=(0, 16),
                               color='r', linewidth=2, show=True) -> None:
    """Draw every edge of a lattice network as a 3D segment, save it as PNG, and show it.

    Parameters
    ----------
    network_dict : dict
        Mapping whose keys are ``"<node_1>_<node_2>"`` strings of flat node
        indices; only the keys are used.
    nxnynz : sequence of int
        Grid dimensions ``(nx, ny, nz)`` used to place nodes in space.
    filename : str
        Output path without extension; the figure is saved to ``"<filename>.png"``.
    xlim, ylim, zlim : (float, float)
        Axis limits. The defaults reproduce the previously hard-coded limits.
        Pass values matched to the grid (e.g. ``(0, nx - 1)``) for other sizes.
    color, linewidth :
        Line style applied to every edge.
    show : bool
        If True (default), call ``plt.show()`` after saving. If False, the
        figure is closed after saving, which suits batch/headless use.
    """
    # NOTE: this updates the global rcParams, so these font sizes also apply
    # to figures created after this call.
    plt.rcParams.update({
        'font.size': 15,              # Default font size for all text
        'axes.labelsize': 15,         # Font size for axis labels
        'xtick.labelsize': 15,        # Font size for x-axis tick labels
        'ytick.labelsize': 15,        # Font size for y-axis tick labels
        'legend.fontsize': 15,        # Font size for legend text
    })

    edges = get_edges_array(network_dict, nxnynz)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Each edge is [x1, y1, z1, x2, y2, z2]; draw it as a single segment.
    for x1, y1, z1, x2, y2, z2 in edges:
        ax.plot([x1, x2], [y1, y2], [z1, z2], color=color, linewidth=linewidth)

    ax.set_xlabel(r'$L_{x}$')
    ax.set_xlim(*xlim)
    ax.set_ylabel(r'$L_{y}$')
    ax.set_ylim(*ylim)
    ax.set_zlabel(r'$L_{z}$')
    ax.set_zlim(*zlim)
    fig.savefig(f'{filename}.png')

    if show:
        plt.show()
    else:
        # Avoid accumulating open figures when called repeatedly without display.
        plt.close(fig)