#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
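This script post-processes the data of Bapst et al.: it fits elastic net,
kernel ridge, decision tree and perceptron regressors to the simulation data
and evaluates them on the experimental data. It relies on several conventions
Bapst used to sample and format his data, so do not use it on other data
without checking that it behaves as intended.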
@author: Victor Bapst et al.
modified by Stefan Hiemer
"""

import numpy as np

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import ElasticNet
from sklearn.kernel_ridge import KernelRidge
from sklearn.metrics import r2_score
from sklearn.tree import DecisionTreeRegressor

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
from tensorflow.keras import backend as K

def coeff_determination(y_true, y_pred):
    """
    Coefficient of determination (R^2) usable as a Keras metric.

    Computes 1 - SS_res / SS_tot with Keras backend ops, where SS_res is the
    sum of squared residuals and SS_tot the sum of squared deviations of
    y_true from its mean. K.epsilon() is added to the denominator so a
    constant y_true does not cause a division by zero.
    """
    residual = y_true - y_pred
    centered = y_true - K.mean(y_true)

    ss_res = K.sum(K.square(residual))
    ss_tot = K.sum(K.square(centered))

    return 1 - ss_res / (ss_tot + K.epsilon())

def build_perceptron(units, activation):
    """
    Build and compile a simple multilayer perceptron for regression.

    units: list of hidden-layer widths; every hidden Dense layer is followed
           by a BatchNormalization layer.
    activation: activation function applied in every hidden layer.

    Returns a keras.Sequential model ending in a single linear output unit,
    compiled with MSE loss, the Adam optimizer and the metrics 'mse' and
    coeff_determination.
    """
    hidden = []
    for width in units:
        hidden += [Dense(width, activation=activation), BatchNormalization()]

    # linear output unit for the scalar regression target
    model = keras.Sequential(hidden + [Dense(1)])

    optimizer = Adam(learning_rate=0.001,
                     beta_1=0.9,
                     beta_2=0.999,
                     amsgrad=False)
    model.compile(optimizer=optimizer,
                  loss='mse',
                  metrics=['mse', coeff_determination])
    return model

def _select_best(candidates, xtrain, ytrain, xtest, ytest):
    """
    Keep the candidate model with the highest R^2 on the test set.

    candidates: iterable of (label, fitted_model) pairs; label is a tuple of
                the hyperparameters that produced the model and is printed
                whenever a new best model is found.
    xtrain, ytrain, xtest, ytest: standardized features and targets.

    Returns (best_model, best_label, r2_train, r2_test); all four are None
    when candidates is empty. Ties keep the earlier candidate.

    NOTE(review): hyperparameters are chosen by test-set R^2, so the reported
    test score is optimistically biased. A held-out validation split of the
    training data (e.g. train_test_split) would give an unbiased estimate.
    """
    best_model = best_label = r2_train = r2_test = None
    for label, model in candidates:
        score_test = r2_score(ytest, model.predict(xtest))
        if best_model is None or score_test > r2_test:
            best_model, best_label = model, label
            r2_train = r2_score(ytrain, model.predict(xtrain))
            r2_test = score_test

            print()
            print(*label, r2_train, r2_test)
            print()

    return best_model, best_label, r2_train, r2_test


def _elastic_net_candidates(xtrain, ytrain):
    """Yield ((alpha, l1_ratio), fitted ElasticNet) over the search grid."""
    for alpha in np.arange(0.1, 10.1, 0.1):
        for l1_ratio in np.arange(0, 1.1, 0.1):
            model = ElasticNet(alpha=alpha,
                               l1_ratio=l1_ratio,
                               max_iter=100000,
                               random_state=0)
            yield (alpha, l1_ratio), model.fit(xtrain, ytrain)


def _kernel_ridge_candidates(xtrain, ytrain):
    """Yield ((alpha, gamma), fitted RBF KernelRidge) over a log-spaced grid."""
    grid = np.logspace(-5, 5, 101)
    for alpha in grid:
        for gamma in grid:
            model = KernelRidge(alpha=alpha, gamma=gamma, kernel='rbf')
            yield (alpha, gamma), model.fit(xtrain, ytrain)


def _tree_candidates(xtrain, ytrain):
    """
    Yield ((max_depth, min_samples_split, min_samples_leaf), fitted tree).

    random_state is fixed because scikit-learn randomly permutes the features
    at every split, so ties between equally good splits would otherwise make
    the results irreproducible.
    """
    for max_depth in [2, 4, 8, 16, 32, 64, None]:
        for min_samples_split in [2, 4, 8, 16, 32, 64]:
            for min_samples_leaf in [2, 4, 8, 16, 32, 64]:
                model = DecisionTreeRegressor(
                    max_depth=max_depth,
                    min_samples_split=min_samples_split,
                    min_samples_leaf=min_samples_leaf,
                    random_state=0)
                yield ((max_depth, min_samples_split, min_samples_leaf),
                       model.fit(xtrain, ytrain))


def _perceptron_candidates(xtrain, ytrain, xtest, ytest):
    """
    Yield ((units, activation), trained network) for each perceptron layout.

    Each configuration gets its own TensorBoard log directory and checkpoint
    prefix under ./models/tensorboard-logs/perceptron, so runs neither
    interleave in TensorBoard nor overwrite each other's checkpoint.

    NOTE(review): the checkpoint keeps the epoch with the lowest training
    'mse', but those weights are never loaded back. The yielded network holds
    the weights of the final epoch.
    """
    for units in [[10, 10], [10, 10, 10], [10, 10, 10, 10]]:
        for activation in ["relu", "sigmoid"]:
            run_path = ('./models/tensorboard-logs/perceptron/' + activation
                        + '_' + '-'.join(str(width) for width in units))
            model_callback = ModelCheckpoint(filepath=run_path,
                                             save_weights_only=True,
                                             monitor='mse',
                                             mode='min',
                                             save_best_only=True,
                                             save_freq='epoch')
            tensorboard_callback = TensorBoard(log_dir=run_path,
                                               histogram_freq=0,
                                               write_graph=True,
                                               write_images=False,
                                               update_freq='epoch',
                                               profile_batch=2,
                                               embeddings_freq=0,
                                               embeddings_metadata=None)

            network = build_perceptron(units, activation)
            # full-batch training: one gradient step per epoch
            network.fit(xtrain, ytrain,
                        epochs=1000,
                        batch_size=ytrain.shape[0],
                        shuffle=True,
                        validation_data=(xtest, ytest),
                        callbacks=[model_callback, tensorboard_callback])
            yield (units, activation), network


def _predict_original_units(model, x, backscaler):
    """Predict with model and undo the target's standardization and log10."""
    y_scaled = np.asarray(model.predict(x)).reshape(-1, 1)
    return 10**backscaler.inverse_transform(y_scaled)


if __name__ == '__main__':

    # load data
    exp = np.loadtxt('DisDataExp.dat', dtype=float, delimiter='\t')
    sim = np.loadtxt('DisDataSim.dat', dtype=float, delimiter='\t')

    # keep the first three columns (strain rate, dislocation density,
    # yield stress) and throw out strain.
    # The simulations form the training set, the experiments the test set.
    train = sim[:, :3]
    test = exp[:, :3]

    # logarithm
    trainlog = np.log10(train)
    testlog = np.log10(test)

    # scale to unit variance and zero mean using training-set statistics
    scaler = StandardScaler().fit(trainlog)
    trainlog_scaled = scaler.transform(trainlog)
    testlog_scaled = scaler.transform(testlog)

    # split into features (first two columns) and target (third column)
    xtrain, ytrain = trainlog_scaled[:, :2], trainlog_scaled[:, 2]
    xtest, ytest = testlog_scaled[:, :2], testlog_scaled[:, 2]

    r2 = []

    # feed to linear models
    lin, _, r2_train, r2_test = _select_best(
        _elastic_net_candidates(xtrain, ytrain), xtrain, ytrain, xtest, ytest)
    print(lin.get_params())
    print(r2_train, r2_test)
    r2.append([r2_train, r2_test])

    # feed to kernel ridge regression
    kridge, _, r2_train, r2_test = _select_best(
        _kernel_ridge_candidates(xtrain, ytrain),
        xtrain, ytrain, xtest, ytest)
    print(kridge.get_params())
    print(r2_train, r2_test)
    r2.append([r2_train, r2_test])

    # feed to tree model
    tree, _, r2_train, r2_test = _select_best(
        _tree_candidates(xtrain, ytrain), xtrain, ytrain, xtest, ytest)
    print(tree.get_params())
    print(r2_train, r2_test)
    r2.append([r2_train, r2_test])

    # feed to keras perceptron
    perceptron, (_units, _act), r2_train, r2_test = _select_best(
        _perceptron_candidates(xtrain, ytrain, xtest, ytest),
        xtrain, ytrain, xtest, ytest)
    print(_units, _act)
    print(r2_train, r2_test)
    r2.append([r2_train, r2_test])

    # scaler for the target column alone, used to undo the standardization
    backscaler = StandardScaler().fit(np.reshape(trainlog[:, -1], (-1, 1)))

    # make final predictions in the original (non-log, unscaled) units
    fitted = [lin, kridge, tree, perceptron]
    header = ('strain rate,dislocation density, yield stress, '
              'elastic net model, kernel ridge regression, decision tree, '
              'perceptron')

    # save to files
    np.savetxt("train_sim-nostrain-train.txt",
               np.concatenate(
                   [train] + [_predict_original_units(m, xtrain, backscaler)
                              for m in fitted],
                   axis=1),
               header=header,
               delimiter=',')

    np.savetxt("test_sim-nostrain-train.txt",
               np.concatenate(
                   [test] + [_predict_original_units(m, xtest, backscaler)
                             for m in fitted],
                   axis=1),
               header=header,
               delimiter=',')
    print(r2)
