import sys
import os
import datetime
from utils import augment_triplet, evaluate

dataset = 'data/FB15k'
path = './record'

iterations = 2

# Fallback KGE hyper-parameters; overridden below when a preset exists for
# the (kge_model, dataset) combination.
kge_model = 'TransE'
kge_batch = 1024
kge_neg = 256
kge_dim = 100
kge_gamma = 24
kge_alpha = 1
kge_lr = 0.001
kge_iters = 10000
kge_tbatch = 16
kge_reg = 0.0
kge_topk = 100

# Last path component of `dataset`, e.g. 'FB15k' for 'data/FB15k'.
# normpath() drops a trailing separator so 'data/FB15k/' resolves the same way.
_dataset_name = os.path.basename(os.path.normpath(dataset))

# KGE presets for RotatE and TransE (both use identical values):
# (batch, neg, dim, gamma, alpha, lr, iters, tbatch)
_TRANSLATIONAL_PRESETS = {
    'FB15k':     (1024, 256, 1000, 24.0, 1.0, 0.0001, 150000, 16),
    'FB15k-237': (1024, 256, 1000, 9.0, 1.0, 0.00005, 100000, 16),
    'wn18':      (512, 1024, 500, 12.0, 0.5, 0.0001, 80000, 8),
    'wn18rr':    (512, 1024, 500, 6.0, 0.5, 0.00005, 80000, 8),
}

# Per-model, per-dataset KGE presets. DistMult and ComplEx entries carry a
# ninth field, the regularization weight:
# (batch, neg, dim, gamma, alpha, lr, iters, tbatch, reg)
_KGE_PRESETS = {
    'RotatE': _TRANSLATIONAL_PRESETS,
    'TransE': _TRANSLATIONAL_PRESETS,
    'DistMult': {
        'FB15k':     (1024, 256, 2000, 500.0, 1.0, 0.001, 150000, 16, 0.000002),
        'FB15k-237': (1024, 256, 2000, 200.0, 1.0, 0.001, 100000, 16, 0.00001),
        'wn18':      (512, 1024, 1000, 200.0, 1.0, 0.001, 80000, 8, 0.00001),
        'wn18rr':    (512, 1024, 1000, 200.0, 1.0, 0.002, 80000, 8, 0.000005),
    },
    'ComplEx': {
        'FB15k':     (1024, 256, 1000, 500.0, 1.0, 0.001, 150000, 16, 0.000002),
        'FB15k-237': (1024, 256, 1000, 200.0, 1.0, 0.001, 100000, 16, 0.00001),
        'wn18':      (512, 1024, 500, 200.0, 1.0, 0.001, 80000, 8, 0.00001),
        'wn18rr':    (512, 1024, 500, 200.0, 1.0, 0.002, 80000, 8, 0.000005),
    },
}

# Unknown model/dataset combinations keep the fallback values above.
_kge_preset = _KGE_PRESETS.get(kge_model, {}).get(_dataset_name)
if _kge_preset is not None:
    (kge_batch, kge_neg, kge_dim, kge_gamma,
     kge_alpha, kge_lr, kge_iters, kge_tbatch) = _kge_preset[:8]
    if len(_kge_preset) > 8:
        # Only the DistMult/ComplEx presets override the regularization weight.
        kge_reg = _kge_preset[8]

# MLN settings per dataset:
# (mln_threshold_of_rule, mln_threshold_of_triplet, weight)
_MLN_PRESETS = {
    'FB15k':     (0.1, 0.7, 0.5),
    'FB15k-237': (0.6, 0.7, 0.5),
    'wn18':      (0.1, 0.5, 100),
    'wn18rr':    (0.1, 0.5, 100),
}
# There is no fallback for these three settings, so fail here with a clear
# message instead of a NameError later, after the record directory exists.
if _dataset_name not in _MLN_PRESETS:
    raise ValueError('No MLN settings for dataset {!r}; expected one of {}'.format(
        _dataset_name, sorted(_MLN_PRESETS)))
mln_threshold_of_rule, mln_threshold_of_triplet, weight = _MLN_PRESETS[_dataset_name]

mln_iters = 1000
mln_lr = 0.0001
mln_threads = 8

# ------------------------------------------

def ensure_dir(d):
    """Create directory *d*, including missing parents, unless it already exists.

    Args:
        d: path of the directory to create.

    Raises:
        OSError: if *d* cannot be created, or if something that is not a
            directory (e.g. a regular file) already occupies *d*.
    """
    try:
        os.makedirs(d)
    except OSError:
        # EAFP rather than exists()-then-makedirs(): no race if another process
        # creates d in between, and a regular file at d is rejected instead of
        # being silently accepted as "the directory".
        if not os.path.isdir(d):
            raise

def cmd_kge(workspace_path, model):
    """Build the ``bash ./kge/kge.sh train ...`` command line for a KGE model.

    The positional arguments passed to kge.sh are, in order: *model*,
    ``dataset``, a literal ``0`` (presumably a device id -- confirm against
    kge.sh), ``kge_batch``, ``kge_neg``, ``kge_dim``, ``kge_gamma``,
    ``kge_alpha``, ``kge_lr``, ``kge_iters``, ``kge_tbatch``,
    *workspace_path* and ``kge_topk``. Model-specific flags are appended
    after them.

    Args:
        workspace_path: per-iteration directory handed to kge.sh.
        model: one of 'RotatE', 'TransE', 'DistMult', 'ComplEx'.

    Returns:
        The command line as a string.

    Raises:
        ValueError: if *model* is not a supported model name.
    """
    # Trailing flags per model; '{reg}' is filled in with kge_reg.
    extra_flags = {
        'RotatE': '-de',
        'TransE': '',
        'DistMult': '-r {reg}',
        'ComplEx': '-de -dr -r {reg}',
    }
    if model not in extra_flags:
        # Fail fast: returning None would only surface later as an obscure
        # TypeError inside os.system().
        raise ValueError('Unsupported KGE model {!r}; expected one of {}'.format(
            model, sorted(extra_flags)))

    command = 'bash ./kge/kge.sh train {} {} 0 {} {} {} {} {} {} {} {} {} {}'.format(
        model, dataset, kge_batch, kge_neg, kge_dim, kge_gamma, kge_alpha,
        kge_lr, kge_iters, kge_tbatch, workspace_path, kge_topk)
    flags = extra_flags[model].format(reg=kge_reg)
    return (command + ' ' + flags) if flags else command

def cmd_mln(main_path, workspace_path=None, preprocessing=False):
    """Build the command line that invokes the ``./mln/mln`` binary.

    Preprocessing mode passes ``<main_path>/train.txt`` as ``-observed``,
    ``<main_path>/hidden.txt`` as ``-out-hidden`` and
    ``<main_path>/mln_saved.txt`` as ``-save``, with ``-thresh-rule
    mln_threshold_of_rule`` and ``-iterations 0``.

    Otherwise the command ``-load``s ``<main_path>/mln_saved.txt``, passes
    ``<workspace_path>/annotation.txt`` as ``-probability``, and directs
    ``-out-prediction`` / ``-out-rule`` to ``pred_mln.txt`` / ``rule.txt``
    inside *workspace_path*. ``-thresh-triplet`` is hard-coded to 1 here.
    Both modes pass ``-threads mln_threads``.

    Args:
        main_path: run directory holding the shared MLN files.
        workspace_path: per-iteration directory; required unless preprocessing.
        preprocessing: if true, build the one-off preprocessing command.

    Returns:
        The command line as a string.

    Raises:
        ValueError: if *workspace_path* is None for a non-preprocessing call.
    """
    if preprocessing:
        return ('./mln/mln -observed {0}/train.txt -out-hidden {0}/hidden.txt '
                '-save {0}/mln_saved.txt -thresh-rule {1} -iterations 0 -threads {2}'
                ).format(main_path, mln_threshold_of_rule, mln_threads)

    if workspace_path is None:
        # Without this check the command would silently reference "None/...".
        raise ValueError('workspace_path is required when preprocessing is False')
    return ('./mln/mln -load {0}/mln_saved.txt -probability {1}/annotation.txt '
            '-out-prediction {1}/pred_mln.txt -out-rule {1}/rule.txt '
            '-thresh-triplet 1 -iterations {2} -lr {3} -threads {4}'
            ).format(main_path, workspace_path, mln_iters, mln_lr, mln_threads)

def save_cmd(save_path):
    """Record the run's settings in *save_path*, one ``name: value`` line each.

    Any existing file is overwritten. Values are read from the module-level
    configuration at call time, in the order written below.
    """
    with open(save_path, 'w') as handle:
        def emit(name, value):
            # One "name: value" line per setting.
            handle.write('{}: {}\n'.format(name, value))

        emit('dataset', dataset)
        emit('iterations', iterations)

        # Knowledge-graph-embedding settings.
        emit('kge_model', kge_model)
        emit('kge_batch', kge_batch)
        emit('kge_neg', kge_neg)
        emit('kge_dim', kge_dim)
        emit('kge_gamma', kge_gamma)
        emit('kge_alpha', kge_alpha)
        emit('kge_lr', kge_lr)
        emit('kge_iters', kge_iters)
        emit('kge_tbatch', kge_tbatch)
        emit('kge_reg', kge_reg)

        # Markov-logic-network settings.
        emit('mln_threshold_of_rule', mln_threshold_of_rule)
        emit('mln_threshold_of_triplet', mln_threshold_of_triplet)
        emit('mln_iters', mln_iters)
        emit('mln_lr', mln_lr)
        emit('mln_threads', mln_threads)

        emit('weight', weight)

# Each run gets its own record sub-directory named after the current local
# timestamp, with the space between date and time replaced by '_'.
time = '_'.join(str(datetime.datetime.now()).split(' '))
path = '{}/{}'.format(path, time)
ensure_dir(path)
save_cmd(path + '/cmd.txt')

# ------------------------------------------

def _run(command):
    """Execute *command* with ``os.system`` and abort the pipeline on failure.

    Every stage consumes files produced by the previous one, so continuing
    after a failed step would only feed missing or stale files downstream.

    Raises:
        RuntimeError: if ``os.system`` reports a non-zero status.
    """
    status = os.system(command)
    if status != 0:
        raise RuntimeError('Command failed (status {}): {}'.format(status, command))


# Seed the run directory with the dataset's training triplets.
# train_augmented.txt starts as a plain copy and is replaced after each iteration.
_run('cp {}/train.txt {}/train.txt'.format(dataset, path))
_run('cp {}/train.txt {}/train_augmented.txt'.format(dataset, path))
# One-off MLN preprocessing pass over the run directory (-iterations 0).
_run(cmd_mln(path, preprocessing=True))

for k in range(iterations):

    # Each iteration works in its own sub-directory <path>/<k>.
    workspace_path = path + '/' + str(k)
    ensure_dir(workspace_path)

    # KGE step: train on the current (possibly augmented) triplets.
    _run('cp {}/train_augmented.txt {}/train_kge.txt'.format(path, workspace_path))
    _run('cp {}/hidden.txt {}/hidden.txt'.format(path, workspace_path))
    _run(cmd_kge(workspace_path, kge_model))

    # MLN step. augment_triplet (from utils) then builds this iteration's
    # train_augmented.txt from pred_mln.txt, the original train.txt and
    # mln_threshold_of_triplet; it is copied back to <path> so the next
    # iteration trains on it.
    _run(cmd_mln(path, workspace_path, preprocessing=False))
    augment_triplet('{}/pred_mln.txt'.format(workspace_path), '{}/train.txt'.format(path), '{}/train_augmented.txt'.format(workspace_path), mln_threshold_of_triplet)
    _run('cp {}/train_augmented.txt {}/train_augmented.txt'.format(workspace_path, path))

    # Presumably scores the weighted combination of MLN and KGE predictions
    # (see utils.evaluate) and writes it to result_kge_mln.txt.
    evaluate('{}/pred_mln.txt'.format(workspace_path), '{}/pred_kge.txt'.format(workspace_path), '{}/result_kge_mln.txt'.format(workspace_path), weight)