#!/usr/bin/env python3
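"""Prediction step of PhaTYP.

Loads the intermediate files produced by the preprocessing step from
--midfolder, runs the fine-tuned BERT classifier over the protein-based
sentences, and writes one lifestyle prediction per contig (temperate or
virulent, with a confidence score) to the CSV file given by --out.
"""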

import os
import argparse
import pickle as pkl

import datasets
import numpy as np
import pandas as pd
from scipy.special import softmax
from transformers import (
    AutoModelForSequenceClassification,
    BertTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)





parser = argparse.ArgumentParser(description="""PhaTYP is a Python library for predicting the lifestyles of bacteriophages.
                                 PhaTYP is a BERT-based model that relies on a protein-based vocabulary to convert DNA sequences into sentences for prediction.""")
parser.add_argument('--out', help='name of the output file', type=str, default='out/example_prediction.csv')
parser.add_argument('--reject', help='threshold to reject prophage', type=float, default=0.2)
parser.add_argument('--midfolder', help='folder to store the intermediate files', type=str, default='phatyp/')
inputs = parser.parse_args()
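# NOTE: --reject is parsed for command-line compatibility but is not used in
# this script; any prophage rejection by this threshold is assumed to happen
# in an upstream step.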

transformer_fn = inputs.midfolder

out_dir = os.path.dirname(inputs.out)
if out_dir != '':
    os.makedirs(out_dir, exist_ok=True)



# Intermediate files produced by the preprocessing step.
with open(f'{transformer_fn}/sentence_id2contig.dict', 'rb') as f:
    id2contig = pkl.load(f)
bert_feat = pd.read_csv(f'{transformer_fn}/bert_feat.csv')
SENTENCE_LEN = 300  # maximum sentence length (tokens per contig)
NUM_TOKEN = 45583   # vocabulary size: number of protein clusters (PCs)

CONFIG_DIR = "config"  # directory with the tokenizer files (protein-cluster vocabulary)

# Load the tokenizer from the local config directory; basic tokenization is
# disabled because the input sentences are already whitespace-separated
# protein-cluster tokens.
tokenizer = BertTokenizer.from_pretrained(CONFIG_DIR, do_basic_tokenize=False)


def preprocess_function(examples):
    """Tokenize a batch of sentences, truncating to the model's maximum length."""
    return tokenizer(examples["text"], truncation=True)



# Build a Hugging Face dataset from the feature table; this script only
# performs inference, so a single "test" split is enough.
test = datasets.Dataset.from_pandas(bert_feat)
data = datasets.DatasetDict({"test": test})


tokenized_data = data.map(preprocess_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
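# "model" is assumed to be the local directory holding the fine-tuned PhaTYP
# classifier checkpoint (config plus weights) with two output labels.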
model = AutoModelForSequenceClassification.from_pretrained("model", num_labels=2)


# Only the evaluation settings are relevant here; no training is performed.
training_args = TrainingArguments(
    output_dir='results',
    per_device_eval_batch_size=32,
)

trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=data_collator,
)




# trainer.predict() already runs inference under torch.no_grad() and returns a
# PredictionOutput namedtuple: (predictions, label_ids, metrics).
pred, label, metric = trainer.predict(tokenized_data["test"])



# Convert the raw logits into class probabilities.
prediction_value = softmax(pred, axis=1)


# Label 1 = temperate, label 0 = virulent; report the probability of the
# predicted class as the confidence score.
all_pred = []
all_score = []
for score in prediction_value:
    label_idx = np.argmax(score)
    if label_idx == 1:
        all_pred.append('temperate')
        all_score.append(score[1])
    else:
        all_pred.append('virulent')
        all_score.append(score[0])



# id2contig is keyed by sentence id in insertion order, so its values are
# assumed to align with the row order of the predictions.
pred_csv = pd.DataFrame({"Contig": id2contig.values(), "Pred": all_pred, "Score": all_score})
pred_csv.to_csv(inputs.out, index=False)