#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Liu Yang <mkliuyang@gmail.com>
import codecs
from collections import defaultdict
from typing import List, Dict

from torch import Tensor

from dlab.data import Sentence, HierarchicalDictionary, BatchedCorpus
from dlab.evaluate import Ratio, F1beta
from dlab.log import log
from .base_task import BaseTask


class HierarchicalClassifier(BaseTask):
    """Base task that assigns each sentence a class from a label hierarchy.

    A class is handled as a *code* (a list of ints) which ``dictionary``
    converts to and from a label item.  All hierarchical metrics in this class
    compare two codes by the length of their common prefix.
    NOTE(review): this presumably means a code lists the nodes on the path
    from the root of the hierarchy to the class -- confirm against
    ``HierarchicalDictionary``.

    Subclasses must implement :meth:`forward`, :meth:`loss` and
    :meth:`decode`; everything else is built on top of those three.
    """

    def __init__(self, dictionary: HierarchicalDictionary, label_name: str):
        """Store the label dictionary and the name of the label to read/write.

        :param dictionary: maps label items to codes and back.
        :param label_name: name of the sentence label holding the gold class.
        """
        # NOTE(review): presumably the metric names this task exposes and the
        # main metric among them -- confirm against BaseTask.__init__.
        super().__init__(['acc', 'nte', 'hf'], 'hf')
        self.dictionary = dictionary
        self.label_name = label_name

    def _get_cls_ids(self, batch_sentence: List[Sentence]):
        """Return the dictionary index (gold code) of every sentence's label."""
        return [self.dictionary.get_idx_for_item(stn.get_label_name(self.label_name)) for stn in batch_sentence]

    def _get_cls_names(self, cls_codes: List[List[int]]):
        """Return the dictionary item for every code in ``cls_codes``."""
        return [self.dictionary.get_item_for_index(cls_code) for cls_code in cls_codes]

    def forward(self, batch_sentence: List[Sentence]) -> Tensor:
        """Compute the logit for a batch of sentences (subclass hook)."""
        raise NotImplementedError()

    def loss(self, batch_sentence: List[Sentence], logit: Tensor) -> Tensor:
        """Compute the training loss for a batch from its logit (subclass hook)."""
        raise NotImplementedError()

    def decode(self, sentences: List[Sentence], logit: Tensor) -> List[List[int]]:
        """Turn a logit into one predicted code per sentence (subclass hook)."""
        raise NotImplementedError()

    def predict(self, sentences: List[Sentence], batch_size=4, override_label_name: str=None):
        """Predict a class for every sentence and store it as a sentence label.

        The sentences are processed ``batch_size`` at a time, so a large input
        is never pushed through :meth:`forward` in a single call.

        :param sentences: sentences to label; modified in place.
        :param batch_size: number of sentences per forward pass.  A value that
            is ``None`` or not positive processes everything in one pass.
        :param override_label_name: label name to write the prediction to;
            defaults to the ``label_name`` given at construction.
        """
        label_name = self.label_name if override_label_name is None else override_label_name
        # Same unwrapping as print_evaluation, so both entry points accept
        # the same kinds of input.
        if isinstance(sentences, BatchedCorpus):
            sentences = list(sentences._iter_item())
        total = len(sentences)
        # max(..., 1) keeps range() valid when the input is empty.
        step = batch_size if batch_size and batch_size > 0 else max(total, 1)
        for start in range(0, total, step):
            batch = sentences[start:start + step]
            cls_codes = self.decode(batch, self.forward(batch))
            for stn, cls_code in zip(batch, cls_codes):
                stn.set_label(label_name, self.dictionary.get_item_for_index(cls_code))

    @staticmethod
    def _acc(cls_names: List[str], gold_cls_names: List[str]):
        """Build a ``Ratio`` from exact matches between predicted and gold names.

        The ratio is updated with the number of positions where both names
        are equal and with ``len(cls_names)``.
        """
        acc = Ratio()
        acc.update(sum(p == j for p, j in zip(cls_names, gold_cls_names)), len(cls_names))
        return acc

    def acc(self, sentences: List[Sentence], logit: Tensor) -> Ratio:
        """Exact-match accuracy of the classes decoded from ``logit``."""
        cls_codes = self.decode(sentences, logit)
        cls_names = self._get_cls_names(cls_codes)
        gold_cls_codes = self._get_cls_ids(sentences)
        gold_cls_names = self._get_cls_names(gold_cls_codes)
        return self._acc(cls_names, gold_cls_names)

    @staticmethod
    def one_tree_match(a: List[int], b: List[int]):
        """Return the length of the common prefix of the codes ``a`` and ``b``."""
        c = 0
        for x, y in zip(a, b):
            if x != y:
                break
            c += 1
        return c

    @staticmethod
    def one_tree_err(a: List[int], b: List[int]):
        """Return how many elements of ``a`` and ``b`` lie outside their common prefix."""
        m = HierarchicalClassifier.one_tree_match(a, b)
        return len(a) + len(b) - 2 * m

    @staticmethod
    def _nte(cls_codes: List[List[int]], gold_cls_codes: List[List[int]]):
        """Build a ``Ratio`` from the negated summed tree error and the sample count.

        The error is negated so that a larger value means a better prediction.
        """
        err = -sum(HierarchicalClassifier.one_tree_err(a, b) for a, b in zip(cls_codes, gold_cls_codes))
        r = Ratio()
        r.update(err, len(cls_codes))
        return r

    def nte(self, sentences: List[Sentence], logit: Tensor) -> Ratio:
        """Negative tree error of the classes decoded from ``logit``."""
        cls_codes = self.decode(sentences, logit)
        gold_cls_codes = self._get_cls_ids(sentences)
        return self._nte(cls_codes, gold_cls_codes)

    @staticmethod
    def _hf(cls_codes: List[List[int]], gold_cls_codes: List[List[int]]):
        """Build a hierarchical ``F1beta`` over all (predicted, gold) code pairs.

        Every pair contributes the predicted code length as predictions, the
        gold code length as gold items and the common prefix length as matches.
        """
        f = F1beta()
        for p, g in zip(cls_codes, gold_cls_codes):
            f.add_pred(len(p))
            f.add_gold(len(g))
            f.add_match(HierarchicalClassifier.one_tree_match(p, g))
        return f

    def hf(self, sentences: List[Sentence], logit: Tensor) -> F1beta:
        """Hierarchical F-score of the classes decoded from ``logit``."""
        cls_codes = self.decode(sentences, logit)
        gold_cls_codes = self._get_cls_ids(sentences)
        return self._hf(cls_codes, gold_cls_codes)

    def print_evaluation(self, sentences: List[Sentence], batch_size=4):
        """Evaluate on ``sentences`` and log global, per-node and per-depth scores.

        Three groups are logged: the global ``acc``/``nte``/``hf`` values, one
        ``F1beta`` per code prefix that occurs in a prediction or gold code
        (keyed by its dictionary item), and one ``F1beta`` per depth.

        ``self.eval()`` is called first and the previous mode is not restored.
        NOTE(review): ``eval`` is inherited; presumably it switches the model
        to inference mode -- callers that keep training should confirm whether
        they need to switch back.

        :param sentences: a list of sentences or a ``BatchedCorpus``.
        :param batch_size: number of sentences per forward pass.
        """
        cls_codes = []
        gold_cls_codes = []
        self.eval()
        if isinstance(sentences, BatchedCorpus):
            sentences = list(sentences._iter_item())
        for batch_sentence in BatchedCorpus(sentences, batch_size):
            logit = self.forward(batch_sentence)
            cls_codes.extend(self.decode(batch_sentence, logit))
            gold_cls_codes.extend(self._get_cls_ids(batch_sentence))

        hf = self._hf(cls_codes, gold_cls_codes)
        nte = self._nte(cls_codes, gold_cls_codes)
        cls_names = self._get_cls_names(cls_codes)
        gold_cls_names = self._get_cls_names(gold_cls_codes)
        acc = self._acc(cls_names, gold_cls_names)
        log.info('GLOBAL: acc:%s nte:%s hf:%s' % (acc, nte, hf))

        # all_f1 is keyed by the dictionary item of a code prefix,
        # hierarchy_f1 by the prefix length (depth, starting at 1).
        all_f1 = defaultdict(F1beta)
        hierarchy_f1 = defaultdict(F1beta)
        item_of = self.dictionary.get_item_for_index
        for pred, gold in zip(cls_codes, gold_cls_codes):
            for depth in range(1, len(pred) + 1):
                hierarchy_f1[depth].add_pred()
                all_f1[item_of(pred[:depth])].add_pred()
            for depth in range(1, len(gold) + 1):
                hierarchy_f1[depth].add_gold()
                all_f1[item_of(gold[:depth])].add_gold()
            # Only prefixes shared by prediction and gold count as matches.
            for depth in range(1, self.one_tree_match(pred, gold) + 1):
                hierarchy_f1[depth].add_match()
                all_f1[item_of(gold[:depth])].add_match()

        log.info('LOCAL: %s' % ('-' * 20))
        # Shorter keys first, ties broken by the key itself.
        for key in sorted(all_f1, key=lambda k: (len(k), k)):
            log.info('\t%s: %s' % (key, all_f1[key]))
        log.info('%s LOCAL END' % ('-' * 20))
        log.info('Hierarchical GROUP: %s' % ('-' * 20))
        for key in sorted(hierarchy_f1):
            log.info('\t%s: %s' % (key, hierarchy_f1[key]))
        log.info('%s Hierarchical GROUP END' % ('-' * 20))

