import mechnet as mn
import my_plot_edges
from my_plot_edges import x_position, z_position, y_position

NX = 16
NY = 2
NZ = 16

def get_parametercollection(seed):
    """Create the mechnet ParameterCollection used to build the network.

    Lattice dimensions come from the module constants NX, NY and NZ. The
    threshold RNG is seeded with ``seed`` and the structure RNG with
    ``seed + 1_000_000``, so the two streams start from different seeds.

    Args:
        seed: base integer seed for the random number generators.

    Returns:
        The populated ``mn.ParameterCollection``.
    """
    params = mn.ParameterCollection()

    # Lattice geometry. Constraints noted for the (currently commented-out)
    # hierarchical gap structure: NX must equal el**h, and NZ must be
    # >= h + 2 + o_set (+2 because of the boundary conditions).
    size_x, size_y, size_z = NX, NY, NZ
    params.set_N(size_x * size_y * size_z)
    params.set_Nx(size_x)
    params.set_Ny(size_y)
    params.set_Nz(size_z)
    params.set_NxNy()

    # Parameters needed only by the hierarchical gap structure:
    # params.set_xhierarchicalelementsize(2)  # log_el(nx) = h <-> nx = el**h
    # params.set_yhierarchicalelementsize(2)
    # params.set_gapzoffset(1)  # == o_set

    params.set_thresholdrng(seed)
    params.set_structurerng(seed + 1_000_000)
    params.set_weibullparameter(1.5)
    params.set_scalingfactor(1.00)
    params.set_uniformupperdirichlet(1.0)
    return params

def construct_system(seed=1001):
    """Build the cubic 3D network and cut it down to the heart shape.

    Args:
        seed: seed passed to ``get_parametercollection``. Defaults to 1001,
            the value that was previously hard-coded here.

    Returns:
        The constructed network, with its ``edge_dict`` replaced by the
        subset of edges kept by ``cut_heart``.
    """
    constructor = mn.cubic3D.Cubic3DFullConnectionConstructor()
    # Hierarchical-gap structure is disabled; re-enabling it presumably also
    # needs the commented-out hierarchical parameters in
    # get_parametercollection.
    # constructor = mn.cubic3D.ShuffledHierarchicalGapsDecorator(constructor)
    constructor = mn.cubic3D.ScalingWeibullThresholdDecorator(constructor)

    # NOTE(review): this boundary chain is built but never handed to the
    # constructor or the network below, so the fixed-lower / displaced-upper
    # boundary conditions are not applied here. Confirm whether it should be
    # wired in (the mechnet API for that is not visible from this file).
    boundaries = mn.general.EmptyBoundariesConstructor()
    boundaries = mn.cubic3D.FixedLowerBoundaryDecorator(boundaries)
    boundaries = mn.cubic3D.UniformDisplacedUpperBoundaryDecorator(boundaries)

    parametercollection = get_parametercollection(seed)
    network = constructor.start_construction(parametercollection)
    # Keep only the edges whose both end nodes lie inside the heart outline.
    network.edge_dict = cut_heart(network.edge_dict)
    return network

def cut_heart(net_dict: dict) -> dict:
    """Return the edges of ``net_dict`` that lie entirely inside the heart outline.

    Edge keys are parsed as two integer node indices joined by an underscore
    (``"<node_1>_<node_2>"``). An edge is kept only when both end nodes pass
    ``is_in_shape`` on their x and z lattice coordinates; the y coordinate
    plays no role in the cut. Values are copied through unchanged and the
    original key order is preserved.

    Args:
        net_dict: mapping of edge key -> edge value.

    Returns:
        A new dict containing only the retained edges.
    """
    nx, ny, nz = NX, NY, NZ

    def _node_inside(node):
        # Only x and z feed the shape test, so y is not computed.
        return is_in_shape(
            x=x_position(node, nx, ny, nz),
            z=z_position(node, nx, ny, nz),
        )

    def _edge_inside(edge_key):
        node_1, node_2 = [int(part) for part in edge_key.split('_')]
        # `and` short-circuits: node_2 is only located if node_1 is inside.
        return _node_inside(node_1) and _node_inside(node_2)

    return {key: value for key, value in net_dict.items() if _edge_inside(key)}

def is_in_shape(x, z, nx=None):
    """Return True if the point (x, z) lies inside the heart outline.

    The outline consists of two bands that mirror each other about
    x = nx / 2 (both bounds strict):

    * left half  (x <= nx/2):  nx/2 - x < z < 1.25*nx - x
    * right half (otherwise):  x - nx/2 < z < x + 0.25*nx

    Args:
        x: x coordinate of the point.
        z: z coordinate of the point.
        nx: lattice size along x that the outline is scaled to. Defaults to
            the module-level ``NX``.

    Returns:
        bool: whether the point is strictly inside the outline.
    """
    if nx is None:
        nx = NX
    half = nx * 0.5
    if x <= half:
        return half - x < z < nx * 1.25 - x
    # Every x not on the left half (including NaN) falls through to the right
    # band; the previous `elif` left the result unbound for such inputs.
    return x - half < z < x + nx * 0.25

if __name__ == "__main__":
    # Build the heart-cut network and plot its remaining edges.
    heart_network = construct_system()
    heart_edges = heart_network.edge_dict
    lattice_dims = [NX, NY, NZ]
    my_plot_edges.plot_edges_mechnet_network(
        heart_edges,
        lattice_dims,
        filename=f"visualize_heart_{NX}x{NY}x{NZ}",
    )

    # NOTE(review): purpose of this status message is not evident from the file.
    print("cutoff prev")