import argparse

import numpy as np
import torch

from model import *
from resnet import *

### GLOBAL VARIABLES
NUM_REPEATS        = 5
DATASET            = 'fmnist'
NET                = 'ConvNet'
CLASSES            = 10
TRAIN_BATCH_SIZE   = 256
TEST_BATCH_SIZE    = 1000
LATENT_DIM         = 256
LEARNING_RATE      = 0.05
INIT_WEIGHT        = 100   # default init weight; matches the fallback used in load_model()
device             = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_args():
    """Parse command-line training options; defaults come from the module-level settings above."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--REGULAR', default=False, action='store_true')
    parser.add_argument('--NUM_REPEATS', type=int, default=NUM_REPEATS)
    parser.add_argument('--DATASET', type=str, default=DATASET)
    parser.add_argument('--NET', type=str, default=NET)
    parser.add_argument('--FREEZE', default=False, action='store_true')
    parser.add_argument('--CLASSES', type=int, default=CLASSES)
    parser.add_argument('--TRAIN_BATCH_SIZE', type=int, default=TRAIN_BATCH_SIZE)
    parser.add_argument('--TEST_BATCH_SIZE', type=int, default=TEST_BATCH_SIZE)
    parser.add_argument('--LATENT_DIM', type=int, default=LATENT_DIM)
    parser.add_argument('--LEARNING_RATE', type=float, default=LEARNING_RATE)
    parser.add_argument('--CLAMP_MIN', type=float)
    parser.add_argument('--CLAMP_MAX', type=float)
    parser.add_argument('--INIT_WEIGHT', type=float, default=INIT_WEIGHT)
    return parser.parse_args()
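
# Example invocation (illustrative; the script name and values are placeholders):
#
#     python train.py --DATASET fmnist --NET ResNet18 --CLAMP_MIN -1.0 --CLAMP_MAX 1.0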

def load_args(path):
    """Restore the argparse Namespace saved as an .npz archive of 0-d arrays."""
    npzfile = np.load(path, allow_pickle=True)
    args_dict = {key: npzfile[key].item() for key in npzfile.files}
    return argparse.Namespace(**args_dict)
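
# For reference, a minimal sketch of the inverse operation (illustrative; save_args is not part
# of the original training code): each value stored by np.savez becomes a 0-d array, which the
# .item() call in load_args() unwraps on load.
def save_args(args, path):
    np.savez(path, **vars(args))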

def load_model(filepath, model_id):
    """Rebuild a trained PredictionNet (and its class centroids, if any) from a results directory."""
    args = load_args(filepath+'/train_settings.npz')
    # Per-backbone configuration; only the first two entries (backbone module, feature dimension)
    # are used in this function.
    nets = {'ResNet50': [ResNet50(), 2048, 200],
            'ResNet34': [ResNet34(), 512, 200],
            'ResNet18': [ResNet18(), 512, 200],
            'ConvNet':  [ConvNet(),  512, 75],
            'FCNet':    [FCNet(),    256, 20]}

    cnn_model = nets[args.NET][0]
    # Older train_settings.npz files may predate the INIT_WEIGHT option; fall back to 100.
    init_weight = getattr(args, 'INIT_WEIGHT', 100)
    dist_model = DistNet(args.LATENT_DIM, args.CLASSES, init_weight, args.REGULAR)
    pred_model = PredictionNet(cnn_model, dist_model, nets[args.NET][1], args.LATENT_DIM, args.CLASSES, args.REGULAR)
    checkpoint = torch.load(filepath+f'/model_{model_id}.pt', map_location=torch.device('cpu'))
    pred_model.load_state_dict(checkpoint['model_state_dict'])
    pred_model = pred_model.to(device)
    pred_model.eval()
    if args.REGULAR:
        centroids = None
    else:
        centroids = torch.tensor(np.load(filepath+f'/centroids_{model_id}.npy')).to(device)
    return args, pred_model, centroids
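
# Example (illustrative only; the directory name is a placeholder): restoring a trained model.
# The directory is expected to contain train_settings.npz, model_<id>.pt and, for non-REGULAR
# runs, centroids_<id>.npy.
#
#     args, pred_model, centroids = load_model('results/fmnist_ConvNet', model_id=0)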

def calc_accuracy(logit, target, return_idx=False):
    """Return top-1 accuracy in percent; optionally also return indices of misclassified samples."""
    b = len(target)
    corrects = (torch.max(logit, 1)[1].view(target.size()) == target)
    accuracy = 100.0 * corrects.sum() / b
    if return_idx:
        w_idx = torch.where(~corrects)[0]
        return accuracy.item(), w_idx
    return accuracy.item()
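
# Quick illustrative check of calc_accuracy: three samples, two classified correctly,
# giving ~66.67% accuracy and the index of the single misclassified sample.
#
#     logits  = torch.tensor([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2]])
#     targets = torch.tensor([0, 1, 1])
#     acc, wrong_idx = calc_accuracy(logits, targets, return_idx=True)  # acc ~ 66.67, wrong_idx == tensor([2])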

def save_model(epoch, pred_model, optimizer, scheduler, path):
    """Write a resumable training checkpoint; the scheduler is optional and may be None."""
    scheduler_state = scheduler.state_dict() if scheduler is not None else None

    torch.save({'epoch': epoch,
                'model_state_dict': pred_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler_state
                }, path)
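
# A minimal sketch of the matching restore step for resuming training (assumed helper, not part
# of the original pipeline); it mirrors the keys written by save_model() above.
def load_training_state(pred_model, optimizer, scheduler, path):
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    pred_model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None and checkpoint['scheduler_state_dict'] is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint['epoch']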