#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
from collections import defaultdict
from typing import List, Union

import torch
from torch import Tensor, nn
import numpy as np

from dlab.data import Dictionary, Sentence, BatchedCorpus
from dlab.evaluate import Ratio, F1beta
from dlab.tasks import BaseTask


class TextClassifier(BaseTask):
    """
    Base class for text classification tasks.

    Provides ``acc``, ``f1``, ``predict`` and ``print_evaluation`` on top of
    the two hooks that subclasses must implement: ``forward`` and ``loss``.

    Class index 0 of ``class_dict`` is treated as the "null"/negative class by
    the F1 metrics: it never counts as a gold or predicted positive, and a
    correct index-0 prediction is not counted as a match. (The per-label
    breakdown printed by ``print_evaluation`` still lists it.)
    """
    def __init__(self, class_dict: Dictionary, target_label: str='category', use_f1=True):
        """
        :param class_dict: label dictionary; maps label names to class indices
            (``get_idx_for_item``) and back (``get_item_for_index``).
        :param target_label: name of the sentence label holding the gold class.
        :param use_f1: if True, pass ``['acc', 'f1']`` and ``'f1'`` to
            ``BaseTask``; otherwise ``['acc']`` and ``'acc'`` (presumably the
            metric list and the primary metric — semantics defined by BaseTask).
        """
        if use_f1:
            super().__init__(['acc', 'f1'], 'f1')
        else:
            super().__init__(['acc'], 'acc')
        self.class_dict = class_dict
        self.target_label = target_label

    def forward(self, batch_sentence: List[Sentence]) -> Tensor:
        """
        Compute class logits for a batch of sentences. Must be overridden.

        The metrics take ``argmax`` over the last dimension and compare it
        one-to-one with the sentences, so the result is presumably shaped
        ``(batch, num_classes)``.
        """
        raise NotImplementedError()

    def loss(self, batch_sentence: List[Sentence], logit: Tensor) -> Tensor:
        """Compute the training loss for ``batch_sentence`` given ``logit``. Must be overridden."""
        raise NotImplementedError()

    def _gold_tensor(self, sentences: List[Sentence]) -> Tensor:
        """Return a LongTensor holding the gold class index of each sentence."""
        gold_ids = [self.class_dict.get_idx_for_item(stn.get_label_name(self.target_label))
                    for stn in sentences]
        # ``self.torch`` comes from BaseTask (presumably ``torch`` or
        # ``torch.cuda`` depending on the device — confirm there).
        return self.torch.LongTensor(gold_ids)

    def acc(self, sentences: List[Sentence], logit: Tensor) -> Ratio:
        """
        Accuracy of ``argmax(logit, -1)`` against the gold labels.

        :param sentences: the sentences ``logit`` was computed for, in order.
        :param logit: output of ``forward`` for ``sentences``.
        :return: a ``Ratio`` of correct predictions over ``len(sentences)``.
        """
        predict_tensor = torch.argmax(logit, -1)
        gold_tag_tensor = self._gold_tensor(sentences)
        match = torch.eq(gold_tag_tensor, predict_tensor).sum().item()
        acc_ratio = Ratio()
        acc_ratio.update(match, len(sentences))
        return acc_ratio

    def f1(self, sentences: List[Sentence], logit: Tensor) -> F1beta:
        """
        Micro F1 over all non-null classes (class index 0 is excluded).

        :param sentences: the sentences ``logit`` was computed for, in order.
        :param logit: output of ``forward`` for ``sentences``.
        :return: an ``F1beta`` holding match / predicted / gold counts.
        """
        predict_tensor = torch.argmax(logit, -1)
        gold_tag_tensor = self._gold_tensor(sentences)
        gold_non_zero_mask = gold_tag_tensor > 0
        predict_non_zero_mask = predict_tensor > 0
        # A hit only counts when the gold class is a real (non-null) class.
        match_mask = torch.eq(gold_tag_tensor, predict_tensor) & gold_non_zero_mask
        f1_beta = F1beta()
        f1_beta.add_match(torch.sum(match_mask).item())
        f1_beta.add_pred(torch.sum(predict_non_zero_mask).item())
        f1_beta.add_gold(torch.sum(gold_non_zero_mask).item())
        return f1_beta

    def predict(self, sentences: List[Sentence], batch_size=4, override_tag_name: str=None) -> None:
        """
        Predict a class for every sentence and store it as a label in place.

        Each sentence gets label ``override_tag_name`` (default:
        ``self.target_label``) set to the argmax class name, with the softmax
        probability of that class as the score. Runs without autograd and
        restores the module's previous train/eval mode afterwards.

        :param sentences: sentences to label (mutated).
        :param batch_size: number of sentences per ``forward`` call.
        :param override_tag_name: label name to write predictions to.
        """
        if override_tag_name is None:
            override_tag_name = self.target_label
        was_training = self.training
        self.eval()
        try:
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                for batch_sentence in BatchedCorpus(sentences, batch_size):
                    logit = self.forward(batch_sentence)
                    probs = torch.softmax(logit, -1).cpu().tolist()
                    pred_ids = torch.argmax(logit, -1).cpu().tolist()
                    for stn, pred_id, stn_probs in zip(batch_sentence, pred_ids, probs):
                        stn.set_label(override_tag_name,
                                      self.class_dict.get_item_for_index(pred_id),
                                      stn_probs[pred_id])
        finally:
            # Don't leave the model in eval mode if it was training before
            # (e.g. when evaluating in the middle of a training loop).
            self.train(was_training)

    def print_evaluation(self, sentences: Union[BatchedCorpus, List[Sentence]]) -> None:
        """
        Predict ``sentences`` and print accuracy, global F1 and per-label F1.

        Predictions are written to label ``'__' + self.target_label`` on each
        sentence so the gold label is not overwritten.

        :param sentences: a ``BatchedCorpus`` or a list of sentences.
        """
        if isinstance(sentences, BatchedCorpus):
            sentences = list(sentences._iter_item())
        pred_label = '__' + self.target_label
        self.predict(sentences, override_tag_name=pred_label)
        # Obtain the null-class name the same way predict() produces label
        # names, so the comparisons below compare like with like.
        null_tag = self.class_dict.get_item_for_index(0)
        all_f1 = F1beta()
        grouped_f1 = defaultdict(F1beta)
        acc = Ratio()
        for stn in sentences:
            p = stn.get_label_name(pred_label)
            g = stn.get_label_name(self.target_label)
            hit = int(p == g)
            grouped_f1[p].add_pred()
            grouped_f1[g].add_gold()
            grouped_f1[p].add_match(hit)
            all_f1.add_gold(int(g != null_tag))
            all_f1.add_pred(int(p != null_tag))
            all_f1.add_match(int(p == g and p != null_tag))
            acc.update(hit, 1)
        print('acc = %s' % acc)
        print('global_f1 = %s' % all_f1)
        for k, v in grouped_f1.items():
            print('\t%s: %s' % (k, v))
