import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.neighbors import NearestNeighbors
from frequency_based_random_sampling import FrequencyBasedRandomSampling
from alibi.explainers import ALE
from encoding_utils import ord2ohe

class ExplanationBasedNeighborhood:
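    """
    Explanation-based neighborhood generator: uses ALE (Accumulated Local
    Effects) explanations to map categorical feature values into a numerical
    effect space, in which meaningful neighbors of an instance are sampled.
    """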
    def __init__(self,
                 X,
                 y,
                 model,
                 dataset):

        # splitting the data into train and test sets with the same random state used for training the model
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        # ensuring the training data contains every possible value of each feature;
        # otherwise, moving one sample carrying the missing value over from the test set
        for f in range(X_train.shape[1]):
            for fv in dataset['feature_values'][f]:
                if fv not in np.unique(X_train[:, f]):
                    idx = np.where(X_test[:, f] == fv)[0][0]
                    X_train = np.r_[X_train, X_test[idx, :].reshape(1, -1)]
                    y_train = np.r_[y_train, y_test[idx]]

        self.X_train = X_train
        # labels are replaced by the model's predictions so that neighborhoods
        # reflect the model's behavior rather than the ground truth
        self.y_train = model.predict(X_train)
        self.model = model
        self.dataset = dataset
        self.discrete_indices = dataset['discrete_indices']
        self.class_set = np.unique(y_train)

    def categoricalSimilarity(self):
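        """
        Computes class-wise ALE effect values for every categorical feature;
        these act as similarity scores between categorical values.
        """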

        # initializing one nested dictionary per class
        categorical_similarity = {c: {} for c in self.class_set}
        categorical_width = {c: {} for c in self.class_set}
        categorical_importance = {c: {} for c in self.class_set}

        # creating an ALE (Accumulated Local Effects) explainer over the model's class probabilities
        ale_explainer = ALE(self.model.predict_proba,
                            feature_names=self.discrete_indices,
                            target_names=self.class_set,
                            low_resolution_threshold=100)
        ale_exp = ale_explainer.explain(self.X_train)

        # extracting global effect values per class and per feature
        # (note: class labels index the ALE value columns, which assumes they
        # are consecutive integers starting at 0)
        for c in self.class_set:
            for f in self.discrete_indices:
                effects = ale_exp.ale_values[f][:, c]
                categorical_similarity[c][f] = pd.Series(effects)
                categorical_width[c][f] = effects.max() - effects.min()
                categorical_importance[c][f] = effects.max()

        # storing the results on the instance
        self.categorical_similarity = categorical_similarity
        self.categorical_width = categorical_width
        self.categorical_importance = categorical_importance

    def neighborhoodModel(self):
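        """
        Fits a 1-NN model per class (on one-hot encoded data) so that a
        representative counterpart of an instance can be retrieved from
        every class.
        """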

        # grouping the data by predicted label and fitting one 1-NN model per class
        class_data = {}
        models = {}
        for c in self.class_set:
            ind_c = np.where(self.y_train == c)[0]
            X_c = self.X_train[ind_c, :]
            class_data[c] = X_c
            X_c_ohe = ord2ohe(X_c, self.dataset)
            model = NearestNeighbors(n_neighbors=1, algorithm='ball_tree', metric='matching')  # matching distance on one-hot vectors
            model.fit(X_c_ohe)
            models[c] = model
        self.class_data = class_data
        self.neighborhood_models = models

    def fit(self):
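        """Prepares the ALE similarity tables and the per-class 1-NN models."""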
        self.categoricalSimilarity()
        self.neighborhoodModel()

    def cat2numConverter(self,
                         x,
                         feature_list=None,
                         label=None):
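        """
        Maps ordinal categorical values to their class-specific ALE effect
        values; handles a single instance (1-D) or a matrix of instances (2-D).
        """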

        # converting features from categorical representation to explanation (effect) representation
        if feature_list is None:
            feature_list = self.discrete_indices

        x_num = x.astype(float)  # astype copies; float dtype prevents ALE effect values from being truncated
        if x_num.ndim == 1:
            # the input is a single instance
            if label is None:
                label = self.model.predict(x.reshape(1,-1))[0]
            for f in feature_list:
                x_num[f] = self.categorical_similarity[label][f][x[f]]
        else:
            # the input is a matrix of instances
            labels = self.model.predict(x)
            for f in feature_list:
                vec = x[:,f]
                vec_converted = np.asarray(list(map(lambda c,v: self.categorical_similarity[c][f][v], labels, vec)))
                x_num[:,f] = vec_converted
        return x_num

    def neighborhoodSampling(self, x, N_samples):
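        """
        Samples a neighborhood around x: draws an oversampled candidate pool
        from the training distribution, scores candidates against per-class
        representatives in the ALE effect space, and keeps the N_samples
        closest ones (plus x itself).
        """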

        # predicting the label of x
        x_c = self.model.predict(x.reshape(1,-1))[0]

        # finding the closest representative (1-NN) of x in every other class
        R = {}
        x_ohe = ord2ohe(x, self.dataset)
        for c in self.class_set:
            if c == x_c:
                R[c] = x
            else:
                distances, indices = self.neighborhood_models[c].kneighbors(x_ohe.reshape(1, -1))
                R[c] = self.class_data[c][indices[0][0]].copy()

        # converting input samples from categorical to numerical (global feature effects) representation
        R_num = {}
        for c, x_counterpart in R.items():
            R_num[c] = self.cat2numConverter(x_counterpart)

        # distance from x to counterparts in numerical (global feature effects) representation
        distance_representative = {}
        x_num = self.cat2numConverter(x)
        feature_width = np.asarray(list(self.categorical_width[x_c].values()))
        for c, x_counterpart in R.items():
            x_counterpart_num = self.cat2numConverter(x_counterpart, label=x_c)
            distance_representative[c] = ((1/feature_width) * abs(x_num - x_counterpart_num))

        # generating an oversampled pool of candidates (20x N_samples) from the distribution of the training data
        S = FrequencyBasedRandomSampling(self.X_train, N_samples * 20)
        S_c = self.model.predict(S)

        # converting random samples from categorical to numerical representation
        S_num = self.cat2numConverter(S)

        # scoring each candidate: feature mismatch with its class representative,
        # effect-space distance to that representative, and the representative's
        # own distance to x
        distance = np.zeros(S.shape[0])
        for i, c in enumerate(S_c):
            distance_identical = (R[c] != S[i,:]).astype(int)
            feature_width = np.asarray(list(self.categorical_width[c].values()))
            distance_effect = ((1/feature_width) * abs(R_num[c] - S_num[i,:]))
            distance[i] = np.mean(distance_identical + distance_effect + distance_representative[c])

        # selecting N_samples based on the calculated distance
        sorted_indices = np.argsort(distance)
        selected_indices = sorted_indices[:N_samples]
        sampled_data = S[selected_indices, :]
        neighborhood_data = np.r_[x.reshape(1, -1), sampled_data]

        # predicting the labels of the neighborhood data and the probability of
        # belonging to the class predicted for x
        neighborhood_labels = self.model.predict(neighborhood_data)
        neighborhood_proba = self.model.predict_proba(neighborhood_data)
        neighborhood_proba = neighborhood_proba[:, neighborhood_labels[0]]

        return neighborhood_data, neighborhood_labels, neighborhood_proba
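
# A minimal usage sketch (illustrative, not part of the class). It assumes an
# ordinal-encoded data matrix X, labels y, a fitted classifier exposing
# predict/predict_proba, and a `dataset` dict as produced by the repository's
# data-preparation utilities (at least the 'feature_values' and
# 'discrete_indices' keys consumed above, plus whatever `ord2ohe` requires):
#
#     neighborhood = ExplanationBasedNeighborhood(X, y, model, dataset)
#     neighborhood.fit()
#     data, labels, proba = neighborhood.neighborhoodSampling(x, N_samples=1000)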