#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
from typing import List, Union

import torch
from torch import Tensor, nn
import numpy as np

from dlab.data.reader import BatchedCorpus
from dlab.data.structures import Dictionary, Sentence
from dlab.eval.entity_decoder import EntityDecoder, entity_f1, tag_f1
from dlab.evaluate import Ratio, F1beta
from dlab.log import log
from dlab.tasks import BaseTask
from dlab.layers.crf import CRF
from dlab.data.utils import pad_batch


class SequenceTagger(BaseTask):
    """
    Base class for sequence tagging tasks.

    Subclasses implement :meth:`forward`; this class builds on the returned
    logits to provide the loss (cross-entropy or CRF), decoding (argmax or
    CRF), the ``acc`` / ``f1`` evaluation methods, and helpers to print an
    evaluation report or write predictions back onto the sentences.
    """

    def __init__(self,
                 tag_dictionary: Dictionary,
                 target_tag: str,
                 use_crf: bool = False,
                 bio_decoders: Union[List[EntityDecoder], None] = None,
                 use_f1: bool = True,
                 ):
        """
        :param tag_dictionary: dictionary mapping tag values to ids and back
        :param target_tag: name of the token tag this tagger reads and writes
        :param use_crf: when True a CRF layer is used for loss and decoding
        :param bio_decoders: optional entity decoders; when given, F1 is
            computed by ``entity_f1`` instead of per tag
        :param use_f1: when True the base task is initialised with
            ``['acc', 'f1']`` and ``'f1'``, otherwise with ``['acc']`` and
            ``'acc'``
        """
        if use_f1:
            super().__init__(['acc', 'f1'], 'f1')
        else:
            super().__init__(['acc'], 'acc')

        self.bio_decoders = bio_decoders

        self.trained_epochs: int = 0

        # set the dictionaries
        self.target_tag_dictionary: Dictionary = tag_dictionary
        self.target_tag: str = target_tag
        self.tagset_size: int = len(tag_dictionary)

        # crf (the attribute only exists when use_crf is True)
        self.use_crf = use_crf
        if self.use_crf:
            self.crf = CRF(self.tagset_size)

        # Legacy tensor-constructor namespace. The methods of this class now
        # allocate on the logits' device instead, but the attribute is kept
        # because code outside this class may still read it.
        self.torch = torch.cuda if torch.cuda.is_available() else torch

    def forward(self, batch_sentence: List[Sentence]) -> Tensor:
        """Produce the logits for a batch of sentences; must be overridden."""
        raise NotImplementedError()

    def loss(self, sentences: List[Sentence], logit: Tensor) -> Tensor:
        """
        Compute the loss from the logits, using the CRF loss or cross-entropy
        depending on ``use_crf``.

        Without a CRF the result is the sum, over sentences, of
        ``cross_entropy`` evaluated on each sentence's unpadded positions.

        :param sentences: List[Sentence]
        :param logit: [batch, seq_len, hidden]
        :return: loss tensor
        """
        tag_ids, seq_length = self._get_seq_and_tags(sentences)
        device = logit.device
        if self.use_crf:
            # The padded gold tensor is only needed by the CRF.
            gold_tag_tensor = torch.as_tensor(pad_batch(tag_ids), dtype=torch.long, device=device)
            mask = self._get_mask(sentences, logit.size(1), device=device)
            return self.crf.forward(logit, gold_tag_tensor, mask)

        # Start from a zero tensor so that an empty batch still yields a Tensor.
        score = logit.new_zeros(())
        for sentence_feats, sentence_tags, sentence_length in zip(logit, tag_ids, seq_length):
            tag_tensor = torch.as_tensor(sentence_tags, dtype=torch.long, device=device)
            # Drop the padded positions before scoring the sentence.
            score = score + torch.nn.functional.cross_entropy(sentence_feats[:sentence_length], tag_tensor)
        return score

    def decode(self, sentences: List[Sentence], logit: Tensor) -> Tensor:
        """
        Compute the best tag sequence from the score matrix. Uses the CRF
        decoder when ``use_crf`` is set and argmax over the last dimension
        otherwise; subclasses may override it for other decoding schemes.

        :param sentences: List[Sentence]
        :param logit: [batch, seq_len, hidden]
        :return: Tensor [batch, seq_len]
        """
        if self.use_crf:
            mask = self._get_mask(sentences, logit.size(1), device=logit.device)
            pred_list = self.crf.decode(logit, mask)
            return torch.as_tensor(pad_batch(pred_list), dtype=torch.long, device=logit.device)
        return torch.argmax(logit, -1)

    def acc(self, sentences: List[Sentence], hidden: Tensor) -> Ratio:
        """
        Token-level accuracy of the decoded tags against the gold tags.

        :param sentences: List[Sentence] carrying the gold ``target_tag``
        :param hidden: logits, passed on to :meth:`decode`
        :return: Ratio updated with (matching tokens, total tokens)
        """
        predict = self.decode(sentences, hidden)
        tag_ids, seq_length = self._get_seq_and_tags(sentences)
        gold_tag_tensor = torch.as_tensor(pad_batch(tag_ids), dtype=torch.long, device=predict.device)
        # `!= 0` turns the 0/1 mask into whatever dtype comparisons return, so
        # it can be combined with `torch.eq` through `&`.
        valid = self._get_mask(sentences, predict.size(1), device=predict.device) != 0
        match = torch.sum(torch.eq(gold_tag_tensor, predict) & valid).item()
        acc_ratio = Ratio()
        acc_ratio.update(match, sum(seq_length))
        return acc_ratio

    def f1(self, sentences: List[Sentence], hidden: Tensor):
        """
        F1 of the decoded tags against the gold tags.

        Without ``bio_decoders`` a tag-level ``F1beta`` is returned in which
        positions whose tag id is 0 are not counted as gold or predicted items
        (presumably id 0 is the "outside" tag -- confirm against the
        dictionary). With ``bio_decoders`` the result of ``entity_f1`` on the
        tag values is returned.

        :param sentences: List[Sentence] carrying the gold ``target_tag``
        :param hidden: logits, passed on to :meth:`decode`
        """
        predict = self.decode(sentences, hidden)
        if self.bio_decoders is None:
            tag_ids, seq_length = self._get_seq_and_tags(sentences)
            gold_tag_tensor = torch.as_tensor(pad_batch(tag_ids), dtype=torch.long, device=predict.device)
            valid = self._get_mask(sentences, predict.size(1), device=predict.device) != 0
            # `x != 0` replaces `1 - (x == 0)`: subtracting from a bool tensor
            # raises in current PyTorch.
            gold_non_zero = (gold_tag_tensor != 0) & valid
            pred_non_zero = (predict != 0) & valid
            f1_score = F1beta()
            f1_score.add_gold(torch.sum(gold_non_zero).item())
            f1_score.add_pred(torch.sum(pred_non_zero).item())
            f1_score.add_match(torch.sum(torch.eq(gold_tag_tensor, predict) & gold_non_zero).item())
            return f1_score

        pred_tag_values = self._ids_to_tags(sentences, predict)
        gold_tag_values = [[t.get_tag(self.target_tag) for t in s] for s in sentences]
        return entity_f1(pred_tag_values, gold_tag_values, self.bio_decoders)

    def print_evaluation(self, sentences: List[Sentence], batch_size=4):
        """
        Predict tags for ``sentences`` and log an F1 report.

        Logs the tag-level F1 when no ``bio_decoders`` are configured, and the
        overall plus (if more than one group) per-entity F1 otherwise. Puts
        the model into eval mode.

        :param sentences: list of sentences, or a BatchedCorpus (which is
            flattened into a list first)
        :param batch_size: optional default=4
        """
        self.eval()
        if isinstance(sentences, BatchedCorpus):
            sentences = list(sentences._iter_item())
        batch_iter = BatchedCorpus(sentences, batch_size)
        results = []
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for batch_sentence in batch_iter:
                pred_ids_tensor = self.decode(batch_sentence, self.forward(batch_sentence))
                results.extend(self._ids_to_tags(batch_sentence, pred_ids_tensor))
        gold_tag_values = [[t.get_tag(self.target_tag) for t in s] for s in sentences]
        if self.bio_decoders is None:
            log.info('evaluation type : TAG F1')
            f1 = tag_f1(results, gold_tag_values)
            log.info('%s' % f1)
        else:
            f1_all, f1_group = entity_f1(results, gold_tag_values, self.bio_decoders, return_entity_grouped=True)
            log.info('evaluation type : ENTITY F1')
            log.info('ALL: %s' % f1_all)
            if len(f1_group) > 1:
                for entity_name, f_beta in f1_group.items():
                    log.info('%s: %s' % (entity_name, str(f_beta)))

    def predict(self, sentences: List[Sentence], batch_size=4, override_tag_name: str = None):
        """
        Predict tags and write them onto the tokens of ``sentences``.

        :param sentences: input and output; every token receives the predicted
            tag through ``token.add_tag``
        :param batch_size: optional default=4
        :param override_tag_name: optional; name of the tag to write. When
            omitted the configured ``target_tag`` is used, which overwrites
            any tag previously stored under that name
        :return:
        """
        self.eval()
        batch_iter = BatchedCorpus(sentences, batch_size)
        if override_tag_name is None:
            override_tag_name = self.target_tag
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for batch_sentence in batch_iter:
                pred_ids_tensor = self.decode(batch_sentence, self.forward(batch_sentence))
                batch_tags = self._ids_to_tags(batch_sentence, pred_ids_tensor)
                for sentence, sentence_tags in zip(batch_sentence, batch_tags):
                    for token, tag_value in zip(sentence, sentence_tags):
                        token.add_tag(override_tag_name, tag_value)

    def _ids_to_tags(self, sentences: List[Sentence], pred_ids_tensor: Tensor) -> list:
        """Map predicted tag ids back to tag values, one list per sentence, dropping padding."""
        pred_ids = pred_ids_tensor.detach().cpu().tolist()
        idx2item = self.target_tag_dictionary.idx2item
        # Index by position (instead of slicing) so that a prediction row that
        # is shorter than its sentence raises rather than being truncated.
        return [
            [idx2item[pred_ids[s_id][t_id]] for t_id in range(len(sentence))]
            for s_id, sentence in enumerate(sentences)
        ]

    def _get_seq_and_tags(self, sentences: List[Sentence]):
        """
        Collect the gold tag ids and the length of every sentence.

        :param sentences: List[Sentence] carrying the gold ``target_tag``
        :return: (tag ids per sentence, sentence lengths)
        """
        get_idx = self.target_tag_dictionary.get_idx_for_item
        seq_length = [len(sentence) for sentence in sentences]
        tags = [[get_idx(token.get_tag(self.target_tag)) for token in sentence] for sentence in sentences]
        return tags, seq_length

    def _get_mask(self, sentences: List[Sentence], seq_len: int, device=None) -> torch.Tensor:
        """
        Build the padding mask: 1 for real tokens, 0 for padded positions.

        :param sentences: List[List[Any]] can get every sentence length
        :param seq_len: int
        :param device: optional device to create the mask on; when omitted the
            legacy ``self.torch.ByteTensor`` constructor is used
        :return: uint8 tensor [len(sentences), seq_len]
        """
        mask = [[i < len(sentence) for i in range(seq_len)] for sentence in sentences]
        if device is None:
            return self.torch.ByteTensor(mask)
        return torch.as_tensor(mask, dtype=torch.uint8, device=device)