#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
import os
from typing import List, Tuple

import torch
from pytorch_pretrained_bert import BertTokenizer, BertForPreTraining
from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP
from torch import Tensor
from torch.nn import CrossEntropyLoss

from dlab.data import Dictionary, Sentence, BatchedCorpus, pad_batch
from dlab.evaluate import Ratio, F1beta
from dlab.tasks import BaseTask
from dlab.utils.bert import get_bert_model, save_bert_model, load_bert_model


class BertPreTrainModel(BaseTask):
    """BERT masked language model and next sentence task to pre-train and fine-tune.

    Wraps a ``BertForPreTraining`` model and its tokenizer. The methods below
    read the following annotations from the input sentences:

    * a ``segment_id`` tag on every token (token type ids),
    * a ``mask_lm_label`` tag on every masked token (the original token text),
    * an ``is_random_next`` label on every sentence (next-sentence target).
    """

    # Label value skipped by the masked-LM loss; padded label slots use it so
    # that they do not contribute to the loss.
    _MLM_IGNORE_INDEX = -1

    def __init__(self, model_name, do_lower_case=True,
                 max_predictions_per_seq=20):
        """
        :param model_name: model identifier handed to ``get_bert_model``
        :param do_lower_case: forwarded to ``get_bert_model`` / ``load_bert_model``
        :param max_predictions_per_seq: length the masked-LM positions and labels
            of every sentence are padded to
        """
        super().__init__([], 'nloss')
        self.model_name = model_name
        self.do_lower_case = do_lower_case
        self.max_predictions_per_seq = max_predictions_per_seq
        self.gpu = torch.cuda.is_available()
        self.tokenizer, model = get_bert_model(BertForPreTraining, self.model_name, do_lower_case)
        self.model: BertForPreTraining = model
        # Namespace used for tensor constructors: ``torch.cuda.LongTensor`` etc.
        # when CUDA is available, the CPU constructors otherwise.
        self.torch = torch.cuda if self.gpu else torch

    def forward(self, batch_sentence: List[Sentence]) -> Tuple[Tensor, Tensor]:
        """
        Tokenize, pad and run a batch of sentences through the BERT model.

        :param batch_sentence: sentences whose tokens carry a ``segment_id`` tag
        :return: prediction_scores, seq_relationship_score
        """
        input_ids = [self.tokenizer.convert_tokens_to_ids([tok.text for tok in s]) for s in batch_sentence]
        type_ids = [[tok.get_tag('segment_id') for tok in s] for s in batch_sentence]
        pad_input_ids, att_mask = pad_batch(input_ids, return_mask=True)
        pad_type_ids = pad_batch(type_ids)
        tensor_pad_input_ids = self.torch.LongTensor(pad_input_ids)
        tensor_pad_type_ids = self.torch.LongTensor(pad_type_ids)
        tensor_att_mask = self.torch.ByteTensor(att_mask)
        # Call the module (not ``.forward``) so that registered hooks are honoured.
        return self.model(
            tensor_pad_input_ids,
            token_type_ids=tensor_pad_type_ids,
            attention_mask=tensor_att_mask,
        )

    def loss(self, batch_sentence: List[Sentence], logit: Tuple[Tensor, Tensor]) -> Tensor:
        """
        Sum of the masked-LM loss and the next-sentence loss for one batch.

        :param batch_sentence: the sentences that produced ``logit``
        :param logit: ``(prediction_scores, seq_relationship_score)`` as returned by ``forward``
        :return: scalar loss tensor
        """
        prediction_scores, seq_relationship_score = logit
        # prepare next_sentence
        next_sentence_label = [1 if s.get_label_name('is_random_next') else 0 for s in batch_sentence]
        next_sentence_label = self.torch.LongTensor(next_sentence_label)
        # prepare mask lm
        masked_lm_positions, masked_lm_labels = self._get_lm_position_and_token_tensor(batch_sentence)
        # Repeat each position across the vocabulary axis so that ``gather``
        # along the sequence axis (-2) picks the full score row of that position.
        # NOTE(review): assumes positions are (batch, max_predictions, 1) and
        # prediction_scores is (batch, seq_len, vocab) - confirm against pad_batch.
        masked_lm_positions = masked_lm_positions.expand(-1, -1, prediction_scores.size(-1))
        masked_lm_scores = torch.gather(prediction_scores, -2, masked_lm_positions)

        loss_fct = CrossEntropyLoss(ignore_index=self._MLM_IGNORE_INDEX)
        masked_lm_loss = loss_fct(masked_lm_scores.view(-1, self.model.config.vocab_size), masked_lm_labels.view(-1))
        next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
        total_loss = masked_lm_loss + next_sentence_loss
        return total_loss

    def acc(self, sentences: List[Sentence], logit: Tuple[Tensor, Tensor]) -> Ratio:
        """Not implemented for the pre-training task."""
        raise NotImplementedError()

    def f1(self, sentences: List[Sentence], logit: Tuple[Tensor, Tensor]) -> F1beta:
        """Not implemented for the pre-training task."""
        raise NotImplementedError()

    def predict(self, sentences: List[Sentence], batch_size=4):
        """Not implemented for the pre-training task."""
        raise NotImplementedError()

    def _get_lm_position_and_token_tensor(self, batch_sentence: List[Sentence]) -> Tuple[Tensor, Tensor]:
        """
        Collect positions and target ids of the masked tokens of every sentence.

        Both results are padded to ``max_predictions_per_seq`` entries per
        sentence. Padded positions point at index 0 and padded labels are set
        to the loss ignore index, so the padded slots are skipped by ``loss``.

        :param batch_sentence: sentences whose masked tokens carry a ``mask_lm_label`` tag
        :return: masked_lm_positions, masked_lm_labels
        """
        # Each position is wrapped in a one-element list so the tensor gets a
        # trailing axis of size 1 that ``loss`` expands over the vocabulary.
        masked_lm_positions = [
            [[position] for position, token in enumerate(s) if token.has_tag('mask_lm_label')]
            for s in batch_sentence
        ]
        masked_lm_positions = pad_batch(masked_lm_positions, max_length=self.max_predictions_per_seq, padding_num=[0])
        masked_lm_positions = self.torch.LongTensor(masked_lm_positions)

        masked_lm_labels = [
            self.tokenizer.convert_tokens_to_ids(
                [token.get_tag('mask_lm_label') for token in s if token.has_tag('mask_lm_label')]
            )
            for s in batch_sentence
        ]
        # Pad with the ignore index: padding with a real vocabulary id would make
        # the padded slots count as genuine prediction targets in the loss.
        masked_lm_labels = pad_batch(
            masked_lm_labels,
            max_length=self.max_predictions_per_seq,
            padding_num=self._MLM_IGNORE_INDEX,
        )
        masked_lm_labels = self.torch.LongTensor(masked_lm_labels)
        return masked_lm_positions, masked_lm_labels

    def save(self, path, **kwargs):
        """Persist the model and tokenizer to ``path`` via ``save_bert_model``."""
        save_bert_model(path, self.model, self.tokenizer)

    def load(self, path, **kwargs):
        """Replace the tokenizer and model with the ones stored at ``path``."""
        self.tokenizer, self.model = load_bert_model(path, BertForPreTraining, self.do_lower_case)

    def print_evaluation(self, sentences: List[Sentence]):
        """
        Print the average loss over ``sentences``, evaluated in batches of 4.

        Switches the model to eval mode; the mode is not restored afterwards.

        :param sentences: a list of sentences or a ``BatchedCorpus``
        """
        if isinstance(sentences, BatchedCorpus):
            sentences = list(sentences._iter_item())
        self.eval()
        batch_iter = BatchedCorpus(sentences, 4)
        loss = Ratio()
        # Evaluation needs no gradients; skipping graph construction saves memory.
        with torch.no_grad():
            for batch in batch_iter:
                batch_loss = float(self.loss(batch, self.forward(batch)))
                # NOTE(review): batch_loss is already a mean (CrossEntropyLoss
                # default reduction); confirm Ratio.update(value, count) expects
                # a per-batch mean rather than a sum.
                loss.update(batch_loss, len(batch))
        print('sentence avg loss = %.7f' % float(loss))

